# Source code for qumphy.models.s4_model

"""
File: qumphy/models/s4_model.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: adapted from https://github.com/HazyResearch/state-spaces/blob/main/example.py .
"""

__all__ = ["S4Model"]

import torch.nn as nn

from qumphy.models.s42 import S4 as S42


class S4Model(nn.Module):
    """Stacked S4 model for sequence modeling.

    This model uses an optional input encoder, multiple S4 residual blocks,
    optional pooling over the sequence dimension, and an optional output
    decoder.
    """

    def __init__(
        self,
        d_input,
        d_output,
        d_state=64,  # MODIFIED: N
        d_model=512,  # MODIFIED: H
        n_layers=4,
        dropout=0.2,
        prenorm=False,
        l_max=1024,
        transposed_input=True,  # behaves like 1d CNN if True else like a RNN with batch_first=True
        bidirectional=True,  # MODIFIED
        layer_norm=True,  # MODIFIED
        pooling=True,  # MODIFIED
    ):
        """Initialize the S4 model.

        Parameters
        ----------
        d_input : int or None
            Number of input channels or features. If None, the encoder is
            disabled and the input is passed through unchanged (it must then
            already have d_model features).
        d_output : int or None
            Number of output channels or features. If None, the decoder is
            disabled.
        d_state : int
            State dimension of the S4 layers.
        d_model : int
            Hidden model dimension used inside the S4 layers.
        n_layers : int
            Number of stacked S4 layers.
        dropout : float
            Dropout probability used inside and after each S4 layer.
        prenorm : bool
            If True, apply normalization before each S4 layer. If False, apply
            normalization after the residual connection.
        l_max : int
            Maximum sequence length used by the S4 layers.
        transposed_input : bool
            If True, the input is expected to have shape
            (batch_size, d_input, sequence_length). If False, the input is
            expected to have shape (batch_size, sequence_length, d_input).
        bidirectional : bool
            If True, use bidirectional S4 layers.
        layer_norm : bool
            If True, use layer normalization. If False, use batch
            normalization.
        pooling : bool
            If True, apply average pooling over the sequence length before the
            decoder.
        """
        super().__init__()

        self.prenorm = prenorm
        self.transposed_input = transposed_input
        # Stored once (previously re-assigned inside the layer loop) so the
        # attribute exists even when n_layers == 0.
        self.layer_norm = layer_norm
        self.pooling = pooling

        # Encoder: pointwise projection d_input -> d_model. A 1x1 Conv1d is
        # used for channel-first input, and a Linear layer for channel-last input.
        # MODIFIED TO ALLOW FOR MODELS WITHOUT ENCODER
        if d_input is None:
            self.encoder = nn.Identity()
        elif transposed_input:
            self.encoder = nn.Conv1d(d_input, d_model, 1)
        else:
            self.encoder = nn.Linear(d_input, d_model)

        # Stack S4 layers as residual blocks. Each block creates its S4 layer,
        # then its norm, then its dropout. This is the original construction
        # order, so seeded parameter initialisation is unchanged.
        self.s4_layers = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        for _ in range(n_layers):
            self.s4_layers.append(
                S42(
                    d_state=d_state,
                    l_max=l_max,
                    d_model=d_model,
                    bidirectional=bidirectional,
                    postact="glu",
                    dropout=dropout,
                    transposed=True,  # blocks operate on (B, d_model, L)
                )
            )
            # MODIFIED TO ALLOW BATCH NORM MODELS
            self.norms.append(
                nn.LayerNorm(d_model) if layer_norm else nn.BatchNorm1d(d_model)
            )
            # NOTE(review): Dropout2d receives a 3D (B, d_model, L) tensor here.
            # Confirm whether nn.Dropout1d is the intended channel-wise dropout
            # for the installed PyTorch version.
            self.dropouts.append(nn.Dropout2d(dropout))

        # Linear decoder d_model -> d_output applied on the last dimension.
        # MODIFIED TO ALLOW FOR MODELS WITHOUT DECODER
        self.decoder = None if d_output is None else nn.Linear(d_model, d_output)

    def _apply_norm(self, norm, x):
        """Apply ``norm`` to a channel-first tensor ``x`` of shape (B, d_model, L)."""
        if self.layer_norm:
            # LayerNorm normalizes the last dimension, so move d_model there and back.
            return norm(x.transpose(-1, -2)).transpose(-1, -2)
        # BatchNorm1d expects channels at dim 1, which is where d_model already is.
        return norm(x)

    def _residual_block(self, x, layer, norm, dropout, rate):
        """Run one residual S4 block, mapping (B, d_model, L) -> (B, d_model, L)."""
        z = x
        if self.prenorm:
            z = self._apply_norm(norm, z)

        # Apply S4 block: the second return value (state) is ignored.
        z, _ = layer(z, rate=rate)
        z = dropout(z)

        # Residual connection.
        x = z + x

        if not self.prenorm:
            # Post-norm must normalize the residual sum x. The previous
            # batch-norm branch normalized z and dropped the skip path.
            x = self._apply_norm(norm, x)
        return x

    def forward(self, x, rate=1.0):
        """Run a forward pass through the S4 model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor. If transposed_input is True, the shape is
            (batch_size, d_input, sequence_length). Otherwise, the shape is
            (batch_size, sequence_length, d_input).
        rate : float
            Sampling rate factor passed to the S4 layers.

        Returns
        -------
        torch.Tensor
            Model output. Let C be d_output if the decoder is enabled, else
            d_model. If pooling is True, the shape is (batch_size, C). If
            pooling is False, the sequence dimension is kept and the layout
            mirrors the input: (batch_size, C, sequence_length) when
            transposed_input is True, otherwise (batch_size, sequence_length, C).
        """
        x = self.encoder(x)
        if not self.transposed_input:
            x = x.transpose(-1, -2)  # (B, L, d_model) -> (B, d_model, L)

        for layer, norm, dropout in zip(self.s4_layers, self.norms, self.dropouts):
            x = self._residual_block(x, layer, norm, dropout, rate)

        x = x.transpose(-1, -2)  # (B, d_model, L) -> (B, L, d_model)

        # MODIFIED ALLOW TO DISABLE POOLING
        if self.pooling:
            # Average pooling over the sequence length.
            x = x.mean(dim=1)  # (B, L, d_model) -> (B, d_model)

        if self.decoder is not None:
            # (B, d_model) -> (B, d_output) if pooling
            # else (B, L, d_model) -> (B, L, d_output)
            x = self.decoder(x)

        if not self.pooling and self.transposed_input:
            # Restore the channel-first layout of the input: (B, L, C) -> (B, C, L).
            x = x.transpose(-1, -2)
        return x