qumphy.models.s4_model module

File: qumphy/models/s4_model.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: adapted from https://github.com/HazyResearch/state-spaces/blob/main/example.py

class qumphy.models.s4_model.S4Model(*args: Any, **kwargs: Any)

Bases: Module

Stacked S4 model for sequence modeling.

This model uses an optional input encoder, multiple S4 residual blocks, optional pooling over the sequence dimension, and an optional output decoder.
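The data flow described above can be sketched in plain Python. This is only an illustration, not the QUMPHY implementation: the names `run_stack`, `residual_block`, `mean_pool`, and `halve` are hypothetical, the `halve` layer is a trivial placeholder for an S4 layer, and mean pooling is an assumption (the description only says that pooling acts over the sequence dimension).

```python
from typing import Callable, List, Optional

Vector = List[float]
Sample = List[Vector]  # one sequence: sequence_length rows of feature values
Layer = Callable[[Sample], Sample]


def residual_block(layer: Layer, x: Sample) -> Sample:
    """Apply `layer` and add its output back onto the input (skip connection)."""
    y = layer(x)
    return [[a + b for a, b in zip(row_x, row_y)] for row_x, row_y in zip(x, y)]


def mean_pool(x: Sample) -> Vector:
    """Average over the sequence dimension (mean pooling is an assumption)."""
    return [sum(column) / len(x) for column in zip(*x)]


def run_stack(
    x: Sample,
    blocks: List[Layer],
    encoder: Optional[Callable] = None,
    decoder: Optional[Callable] = None,
    pooling: bool = True,
):
    """Hypothetical stand-in: encoder -> residual blocks -> pooling -> decoder."""
    if encoder is not None:  # optional input encoder
        x = encoder(x)
    for layer in blocks:  # stacked residual blocks
        x = residual_block(layer, x)
    if pooling:  # optional pooling over the sequence dimension
        x = mean_pool(x)
    if decoder is not None:  # optional output decoder
        x = decoder(x)
    return x


def halve(sample: Sample) -> Sample:
    """Placeholder layer; a real model would use an S4 layer here."""
    return [[0.5 * value for value in row] for row in sample]


sample = [[1.0, 2.0], [3.0, 4.0]]  # sequence_length = 2, two features per step

# Pooling plus a toy decoder collapses the sequence to a single output vector.
pooled = run_stack(sample, blocks=[halve, halve], decoder=lambda v: [sum(v)])

# Without pooling (and without a decoder) the sequence dimension is kept.
unpooled = run_stack(sample, blocks=[halve, halve], pooling=False)
```

Each stage is skipped when it is disabled, which is why the result is either a single vector per sample or a full sequence.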

forward(x, rate=1.0)

Run a forward pass through the S4 model.

Parameters:
  • x (torch.Tensor) – Input tensor. If transposed_input is True, the shape is (batch_size, d_input, sequence_length). Otherwise, the shape is (batch_size, sequence_length, d_input).

  • rate (float) – Sampling rate factor passed to the S4 layers.

Returns:

Model output. If pooling is True and the decoder is enabled, the output has shape (batch_size, d_output). If pooling is False, the output keeps the sequence dimension.

Return type:

torch.Tensor
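The two input layouts accepted by forward can be made explicit with a small helper. This is a minimal sketch in plain Python; `expected_input_shape` and `feature_axis` are hypothetical names that are not part of qumphy, and the sizes used below are arbitrary.

```python
from typing import Tuple


def expected_input_shape(
    batch_size: int,
    sequence_length: int,
    d_input: int,
    transposed_input: bool = False,
) -> Tuple[int, int, int]:
    """Return the input shape documented for forward (hypothetical helper)."""
    if transposed_input:
        # Features come before the sequence dimension.
        return (batch_size, d_input, sequence_length)
    # Default layout: the sequence dimension comes before the features.
    return (batch_size, sequence_length, d_input)


def feature_axis(transposed_input: bool = False) -> int:
    """Index of the d_input axis in the layouts above (hypothetical helper)."""
    return 1 if transposed_input else 2


default_shape = expected_input_shape(8, 1000, 3)
transposed_shape = expected_input_shape(8, 1000, 3, transposed_input=True)
```

A tensor with one of these shapes, for example `torch.randn(*default_shape)`, could then be passed to an already constructed model as `model(x, rate=1.0)`. The constructor arguments are not documented in this section, so the model construction is left out here.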