"""
File: qumphy/models/s4_model.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: adapted from https://github.com/HazyResearch/state-spaces/blob/main/example.py .
"""
__all__ = ["S4Model"]
import torch.nn as nn
from qumphy.models.s42 import S4 as S42
class S4Model(nn.Module):
    """Stacked S4 model for sequence modeling.

    The model chains an optional input encoder, ``n_layers`` residual S4
    blocks (S4 layer -> dropout -> residual add, with normalization applied
    either before the S4 layer or after the residual add), optional mean
    pooling over the sequence dimension, and an optional linear decoder.

    Internally all S4 blocks operate on tensors of shape
    ``(batch_size, d_model, sequence_length)``.
    """

    def __init__(
        self,
        d_input,
        d_output,
        d_state=64,
        d_model=512,
        n_layers=4,
        dropout=0.2,
        prenorm=False,
        l_max=1024,
        transposed_input=True,
        bidirectional=True,
        layer_norm=True,
        pooling=True,
    ):
        """Initialize the S4 model.

        Parameters
        ----------
        d_input : int or None
            Number of input channels or features. If None, the encoder is
            replaced by an identity and the input must already have
            ``d_model`` channels.
        d_output : int or None
            Number of output channels or features. If None, the decoder is
            disabled and the model returns ``d_model`` features.
        d_state : int
            State dimension passed to each S4 layer.
        d_model : int
            Hidden model dimension used by the encoder, the S4 layers, the
            normalization layers and the decoder input.
        n_layers : int
            Number of stacked S4 residual blocks.
        dropout : float
            Dropout probability, passed to each S4 layer and also applied
            to the output of each S4 layer.
        prenorm : bool
            If True, apply normalization before each S4 layer. If False,
            apply normalization after the residual connection.
        l_max : int
            Maximum sequence length passed to each S4 layer.
        transposed_input : bool
            If True, the input is expected to have shape
            (batch_size, d_input, sequence_length), like a 1d CNN. If False,
            the input is expected to have shape
            (batch_size, sequence_length, d_input), like an RNN with
            ``batch_first=True``.
        bidirectional : bool
            Passed to each S4 layer to select bidirectional operation.
        layer_norm : bool
            If True, use layer normalization. If False, use batch
            normalization.
        pooling : bool
            If True, average over the sequence length before the decoder.
        """
        super().__init__()

        self.prenorm = prenorm
        self.transposed_input = transposed_input
        # Stored once up front so the attribute exists even when n_layers == 0.
        self.layer_norm = layer_norm
        self.pooling = pooling

        # Input encoder. A kernel-size-1 convolution is the channels-first
        # counterpart of a per-timestep linear layer.
        if d_input is None:
            self.encoder = nn.Identity()
        elif transposed_input:
            self.encoder = nn.Conv1d(d_input, d_model, 1)
        else:
            self.encoder = nn.Linear(d_input, d_model)

        # Stack S4 layers as residual blocks; the three lists are iterated
        # in lockstep in forward().
        self.s4_layers = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        for _ in range(n_layers):
            self.s4_layers.append(
                S42(
                    d_state=d_state,
                    l_max=l_max,
                    d_model=d_model,
                    bidirectional=bidirectional,
                    postact="glu",
                    dropout=dropout,
                    transposed=True,
                )
            )
            if layer_norm:
                self.norms.append(nn.LayerNorm(d_model))
            else:
                self.norms.append(nn.BatchNorm1d(d_model))
            # NOTE(review): this module receives 3-D (B, d_model, L) tensors;
            # recent PyTorch releases warn about Dropout2d on 3-D input.
            # Consider nn.Dropout1d once the minimum torch version allows it.
            self.dropouts.append(nn.Dropout2d(dropout))

        # Output decoder (None disables it).
        if d_output is None:
            self.decoder = None
        else:
            self.decoder = nn.Linear(d_model, d_output)

    def _apply_norm(self, norm, x):
        """Normalize a (B, d_model, L) tensor over its channel dimension."""
        if self.layer_norm:
            # nn.LayerNorm(d_model) normalizes the last dimension, so move
            # the channels last for the call and restore the layout after.
            return norm(x.transpose(-1, -2)).transpose(-1, -2)
        # nn.BatchNorm1d(d_model) already expects channels at dim 1.
        return norm(x)

    def forward(self, x, rate=1.0):
        """Run a forward pass through the S4 model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor. If transposed_input is True, the shape is
            (batch_size, d_input, sequence_length). Otherwise, the shape is
            (batch_size, sequence_length, d_input).
        rate : float
            Forwarded to every S4 layer as the ``rate`` keyword argument.

        Returns
        -------
        torch.Tensor
            Model output; ``d_out`` below is ``d_output`` when the decoder
            is enabled and ``d_model`` otherwise. With pooling the shape is
            (batch_size, d_out). Without pooling the sequence dimension is
            kept and the layout mirrors the input:
            (batch_size, d_out, sequence_length) if transposed_input is
            True, else (batch_size, sequence_length, d_out).
        """
        x = self.encoder(x)
        if not self.transposed_input:
            # (B, L, d_model) -> (B, d_model, L)
            x = x.transpose(-1, -2)

        for layer, norm, dropout in zip(self.s4_layers, self.norms, self.dropouts):
            # Each iteration maps (B, d_model, L) -> (B, d_model, L).
            z = x
            if self.prenorm:
                z = self._apply_norm(norm, z)

            # Apply the S4 block; the returned state is not used.
            z, _ = layer(z, rate=rate)

            # Dropout on the output of the S4 block.
            z = dropout(z)

            # Residual connection.
            x = z + x

            if not self.prenorm:
                # Post-norm acts on the residual sum for both norm types.
                # (The batch-norm path previously normalized only ``z``,
                # which discarded the residual connection.)
                x = self._apply_norm(norm, x)

        x = x.transpose(-1, -2)  # (B, d_model, L) -> (B, L, d_model)

        if self.pooling:
            # Average pooling over the sequence length.
            x = x.mean(dim=1)

        if self.decoder is not None:
            # (B, d_model) -> (B, d_output) with pooling,
            # (B, L, d_model) -> (B, L, d_output) without.
            x = self.decoder(x)

        if not self.pooling and self.transposed_input:
            # Return channels-first output to match the input layout.
            x = x.transpose(-1, -2)  # (B, L, d_out) -> (B, d_out, L)

        return x