Source code for qumphy.models.utils.conv_blocks

# Taken from https://github.com/thuml/Time-Series-Library

import torch
import torch.nn as nn


[docs] class Inception_Block_V1(nn.Module): def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True): super(Inception_Block_V1, self).__init__() self.in_channels = in_channels self.out_channels = out_channels self.num_kernels = num_kernels kernels = [] for i in range(self.num_kernels): kernels.append( nn.Conv2d(in_channels, out_channels, kernel_size=2 * i + 1, padding=i) ) self.kernels = nn.ModuleList(kernels) if init_weight: self._initialize_weights() def _initialize_weights(self): for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") if m.bias is not None: nn.init.constant_(m.bias, 0)
[docs] def forward(self, x): res_list = [] for i in range(self.num_kernels): res_list.append(self.kernels[i](x)) res = torch.stack(res_list, dim=-1).mean(-1) return res
[docs] class Inception_Block_V2(nn.Module): def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True): super(Inception_Block_V2, self).__init__() self.in_channels = in_channels self.out_channels = out_channels self.num_kernels = num_kernels kernels = [] for i in range(self.num_kernels // 2): kernels.append( nn.Conv2d( in_channels, out_channels, kernel_size=[1, 2 * i + 3], padding=[0, i + 1], ) ) kernels.append( nn.Conv2d( in_channels, out_channels, kernel_size=[2 * i + 3, 1], padding=[i + 1, 0], ) ) kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=1)) self.kernels = nn.ModuleList(kernels) if init_weight: self._initialize_weights() def _initialize_weights(self): for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") if m.bias is not None: nn.init.constant_(m.bias, 0)
[docs] def forward(self, x): res_list = [] for i in range(self.num_kernels // 2 * 2 + 1): res_list.append(self.kernels[i](x)) res = torch.stack(res_list, dim=-1).mean(-1) return res