"""
File: qumphy/models/inception_mantas.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: 1D Inception-style CNN classifier, adapted from MATLAB code from KTU.
"""
import torch
import torch.nn as nn
class Inception1DBlock(nn.Module):
    """
    Single Inception block for 1D signals.

    Four parallel branches read the same input and their outputs are
    concatenated along the channel dimension:

    - Branch 1: 1x1 conv
    - Branch 2: 1x1 conv -> 3x1 conv
    - Branch 3: 1x1 conv -> 5x1 conv
    - Branch 4: 3x1 maxpool -> 1x1 conv

    Every convolution is followed by BatchNorm1d and ReLU. All operations use
    stride 1 with length-preserving padding, so the output length equals the
    input length.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    branch_channels : tuple of four int, optional (keyword-only)
        Output channels of branches 1-4, in order. The default
        ``(16, 16, 8, 8)`` reproduces the original block (48 output channels).

    Attributes
    ----------
    out_channels : int
        Number of channels produced by ``forward``, i.e. ``sum(branch_channels)``.

    Raises
    ------
    ValueError
        If ``branch_channels`` does not contain exactly four values.
    """

    def __init__(self, in_channels, *, branch_channels=(16, 16, 8, 8)):
        super().__init__()
        branch_channels = tuple(branch_channels)
        if len(branch_channels) != 4:
            raise ValueError(
                f"branch_channels must contain 4 values, got {len(branch_channels)}"
            )
        c1, c2, c3, c4 = branch_channels
        self.out_channels = sum(branch_channels)

        # Layers are splatted into flat nn.Sequential containers (not nested
        # Sequentials) so parameter names such as "branch2.3.weight" stay
        # identical to checkpoints saved with the original hand-written layout.

        # Branch 1: 1x1 conv
        self.branch1 = nn.Sequential(*self._conv_bn_relu(in_channels, c1, 1))
        # Branch 2: 1x1 conv -> 3x1 conv
        self.branch2 = nn.Sequential(
            *self._conv_bn_relu(in_channels, c2, 1),
            *self._conv_bn_relu(c2, c2, 3),
        )
        # Branch 3: 1x1 conv -> 5x1 conv
        self.branch3 = nn.Sequential(
            *self._conv_bn_relu(in_channels, c3, 1),
            *self._conv_bn_relu(c3, c3, 5),
        )
        # Branch 4: 3x1 maxpool -> 1x1 conv (stride 1, padding 1 keeps length)
        self.branch4_pool = nn.MaxPool1d(kernel_size=3, stride=1, padding=1)
        self.branch4_conv = nn.Sequential(*self._conv_bn_relu(in_channels, c4, 1))

    @staticmethod
    def _conv_bn_relu(in_ch, out_ch, kernel_size):
        """Return the [Conv1d, BatchNorm1d, ReLU] layers of one stride-1 'same' conv unit."""
        return [
            nn.Conv1d(in_ch, out_ch, kernel_size=kernel_size, padding="same"),
            nn.BatchNorm1d(out_ch),
            nn.ReLU(),
        ]

    def forward(self, x):
        """
        Apply the four branches to ``x`` and concatenate their outputs.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape (B, in_channels, L).

        Returns
        -------
        torch.Tensor
            Tensor of shape (B, out_channels, L).
        """
        b1 = self.branch1(x)
        b2 = self.branch2(x)
        b3 = self.branch3(x)
        b4 = self.branch4_conv(self.branch4_pool(x))
        return torch.cat([b1, b2, b3, b4], dim=1)
class Inception1DNet(nn.Module):
    """
    1D Inception-style classifier.

    Pipeline:

    1. Stem: 7x1 conv (stride 2) -> BatchNorm1d -> ReLU -> 3x1 max pool
       (stride 2).
    2. One Inception1DBlock.
    3. 1x1 conv reducing its 48 output channels to 32.
    4. Global average pooling followed by an MLP head
       (32 -> 100 -> num_classes, ReLU and dropout p=0.5 in between).
    """

    def __init__(self, in_channels: int, num_classes: int):
        """
        Parameters
        ----------
        in_channels : int
            Number of input channels.
        num_classes : int
            Number of output classes.
        """
        super().__init__()
        self.initial_processing = nn.Sequential(
            # PyTorch raises ValueError for padding="same" on strided
            # convolutions, so pad explicitly: kernel 7, stride 2, padding 3
            # gives an output length of ceil(L / 2).
            nn.Conv1d(in_channels, 32, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=3, stride=2, padding=1),
        )
        self.inception = Inception1DBlock(in_channels=32)
        # Inception block outputs 16 + 16 + 8 + 8 = 48 channels;
        # "reduce_conv" after concatenation maps them back to 32.
        self.reduce = nn.Sequential(
            nn.Conv1d(48, 32, kernel_size=1, padding="same"),
            nn.BatchNorm1d(32),
            nn.ReLU(),
        )
        self.final_processing = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),  # (B, 32, L') -> (B, 32, 1)
            nn.Flatten(-2, -1),  # (B, 32, 1) -> (B, 32)
            nn.Linear(32, 100),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(100, num_classes),
        )

    def forward(self, x):
        """
        Run the network on a batch of signals.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape (batch_size, in_channels, length).

        Returns
        -------
        torch.Tensor
            Class scores of shape (batch_size, num_classes). The last layer is
            a plain Linear layer, so no softmax is applied.
        """
        x = self.initial_processing(x)
        x = self.inception(x)
        x = self.reduce(x)
        return self.final_processing(x)