"""
File: qumphy/models/ppnet.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: PPNet model implementation from 10.1109/JSEN.2020.2990864.
"""
import torch.nn as nn
class PPNet(nn.Module):
    """CNN-LSTM network for one-dimensional sequence classification.

    Two convolution stages (Conv1d -> ReLU -> MaxPool1d -> Dropout) extract
    local features, two stacked LSTMs process the resulting sequence, and a
    linear layer maps the last LSTM time step to class logits.

    Architecture follows the paper with DOI 10.1109/JSEN.2020.2990864.
    """

    # Hyperparameters shared by the layer definitions and by the output
    # sequence-length computation, so the two cannot drift apart.
    _CONV_CHANNELS = 20
    _CONV_KERNEL = 9
    _POOL_KERNEL = 4
    _NUM_CONV_BLOCKS = 2
    _LSTM1_HIDDEN = 64
    _LSTM2_HIDDEN = 128
    _DROPOUT = 0.1

    def __init__(self, input_length, num_classes=2, in_channels=1):
        """Initialize the PPNet model.

        Parameters
        ----------
        input_length : int
            Length of the one-dimensional input sequence. It is only used to
            compute ``self.seq_len``; the layers themselves do not depend on
            it.
        num_classes : int, optional
            Number of output classes (output size of the final linear layer).
            Default is 2.
        in_channels : int, optional
            Number of channels of the input signal (input channels of the
            first convolution). Default is 1.
        """
        super().__init__()

        # Convolutional feature extractor. Layers are appended in the same
        # order as the original hand-written Sequential, so submodule indices
        # (and therefore state_dict keys such as "CNN.0.weight") are stable.
        layers = []
        channels = in_channels
        for _ in range(self._NUM_CONV_BLOCKS):
            layers += [
                nn.Conv1d(
                    in_channels=channels,
                    out_channels=self._CONV_CHANNELS,
                    kernel_size=self._CONV_KERNEL,
                ),
                nn.ReLU(),
                nn.MaxPool1d(kernel_size=self._POOL_KERNEL),
                nn.Dropout(self._DROPOUT),
            ]
            channels = self._CONV_CHANNELS
        self.CNN = nn.Sequential(*layers)

        # Sequence length the LSTMs see for an input of ``input_length``
        # samples. It is not used inside this class. A value < 1 means inputs
        # of that length are too short to pass through the CNN stage.
        self.seq_len = self._calc_out_len(
            input_length,
            kernel=self._CONV_KERNEL,
            pool=self._POOL_KERNEL,
            repeat=self._NUM_CONV_BLOCKS,
        )

        # Recurrent layers operating on (batch, seq_len, features).
        self.lstm1 = nn.LSTM(
            input_size=self._CONV_CHANNELS,
            hidden_size=self._LSTM1_HIDDEN,
            num_layers=1,
            batch_first=True,
        )
        self.dropout_lstm1 = nn.Dropout(self._DROPOUT)
        self.lstm2 = nn.LSTM(
            input_size=self._LSTM1_HIDDEN,
            hidden_size=self._LSTM2_HIDDEN,
            num_layers=1,
            batch_first=True,
        )
        self.dropout_lstm2 = nn.Dropout(self._DROPOUT)
        self.fc = nn.Linear(self._LSTM2_HIDDEN, num_classes)

    @staticmethod
    def _calc_out_len(length, kernel, pool, repeat):
        """Calculate the sequence length after repeated convolution and pooling.

        Each block applies an unpadded, stride-1 convolution, which removes
        ``kernel - 1`` samples. It is followed by max pooling whose stride
        equals its kernel size, which floor-divides the length by ``pool``.

        Parameters
        ----------
        length : int
            Input sequence length.
        kernel : int
            Convolution kernel size.
        pool : int
            Pooling kernel size.
        repeat : int
            Number of repeated convolution and pooling blocks.

        Returns
        -------
        int
            Output sequence length after repeated convolution and pooling.
        """
        for _ in range(repeat):
            length = (length - (kernel - 1)) // pool
        return length

    def forward(self, x):
        """Run a forward pass through the PPNet model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, in_channels, length).

        Returns
        -------
        torch.Tensor
            Output logits of shape (batch_size, num_classes).
        """
        x = self.CNN(x)  # (batch, 20, seq_len)
        # nn.LSTM with batch_first=True expects (batch, seq_len, features).
        x = x.permute(0, 2, 1)
        x, _ = self.lstm1(x)  # (batch, seq_len, 64)
        x = self.dropout_lstm1(x)
        x, _ = self.lstm2(x)  # (batch, seq_len, 128)
        x = self.dropout_lstm2(x)
        # Classify from the output at the last time step of the second LSTM.
        x = x[:, -1, :]  # (batch, 128)
        return self.fc(x)  # (batch, num_classes)