qumphy.models.s4_model module
File: qumphy/models/s4_model.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: Adapted from https://github.com/HazyResearch/state-spaces/blob/main/example.py
- class qumphy.models.s4_model.S4Model(*args: Any, **kwargs: Any)
Bases: Module
Stacked S4 model for sequence modeling.
This model uses an optional input encoder, multiple S4 residual blocks, optional pooling over the sequence dimension, and an optional output decoder.
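The data flow described above can be illustrated with a dependency-free toy in plain Python. This is a sketch under stated assumptions, not the qumphy implementation: the encoder and decoder are hypothetical bias-free per-position linear maps, the S4 layer is replaced by a hypothetical causal moving-average filter (`toy_sequence_layer`), normalization and dropout are omitted, and pooling is assumed to be a mean over the sequence dimension, as in the upstream HazyResearch example. A single batch element is represented as a nested list with layout (sequence_length, channels).

```python
from typing import List, Union

Seq = List[List[float]]  # one batch element, layout (sequence_length, channels)
Weight = List[List[float]]  # (in_channels, out_channels)


def linear(x: Seq, weight: Weight) -> Seq:
    """Apply the same bias-free linear map at every sequence position."""
    out_dim = len(weight[0])
    return [
        [sum(row[i] * weight[i][j] for i in range(len(row))) for j in range(out_dim)]
        for row in x
    ]


def toy_sequence_layer(x: Seq, window: int = 2) -> Seq:
    """Hypothetical stand-in for an S4 layer: causal moving average per channel."""
    out = []
    for t in range(len(x)):
        span = x[max(0, t - window + 1) : t + 1]  # current and previous steps only
        out.append([sum(step[c] for step in span) / len(span) for c in range(len(x[t]))])
    return out


def residual_block(x: Seq) -> Seq:
    """output = layer(x) + x; normalization and dropout are omitted here."""
    z = toy_sequence_layer(x)
    return [[zc + xc for zc, xc in zip(zt, xt)] for zt, xt in zip(z, x)]


def mean_pool(x: Seq) -> List[float]:
    """Average over the sequence dimension, keeping one value per channel."""
    return [sum(step[c] for step in x) / len(x) for c in range(len(x[0]))]


def toy_s4_model(
    x: Seq,
    enc_w: Weight,
    dec_w: Weight,
    n_layers: int = 2,
    pooling: bool = True,
    decode: bool = True,
) -> Union[List[float], Seq]:
    h = linear(x, enc_w)  # encoder: d_input -> d_model channels
    for _ in range(n_layers):  # residual blocks preserve (sequence_length, d_model)
        h = residual_block(h)
    if pooling:
        h = [mean_pool(h)]  # sequence dimension collapsed to a single step
    if decode:
        h = linear(h, dec_w)  # decoder: d_model -> d_output channels
    return h[0] if pooling else h


x = [[1.0], [2.0], [3.0]]  # sequence_length=3, d_input=1
enc_w = [[1.0, -1.0]]  # d_input=1 -> d_model=2
dec_w = [[1.0], [0.0]]  # d_model=2 -> d_output=1
print(toy_s4_model(x, enc_w, dec_w))  # [6.75]
print(len(toy_s4_model(x, enc_w, dec_w, pooling=False)))  # 3: sequence dimension kept
```

Swapping the S4 layer for a moving average keeps the sketch runnable without PyTorch; the shape bookkeeping (encoder, shape-preserving residual blocks, optional pooling, optional decoder) is the part meant to carry over.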
- forward(x, rate=1.0)
Run a forward pass through the S4 model.
- Parameters:
x (torch.Tensor) – Input tensor. If transposed_input is True, the shape is (batch_size, d_input, sequence_length). Otherwise, the shape is (batch_size, sequence_length, d_input).
rate (float) – Sampling rate factor passed to the S4 layers.
- Returns:
Model output. If pooling is True and the decoder is enabled, the output has shape (batch_size, d_output). If pooling is False, the output keeps the sequence dimension.
- Return type:
torch.Tensor
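The input layouts and return shapes documented for forward can be summarized as a small shape check. The helper `expected_forward_shapes` below is hypothetical and not part of qumphy; it encodes the rules stated above for a model with the decoder enabled, plus one assumption: when pooling is False, the output is assumed to be channel-last, i.e. (batch_size, sequence_length, d_output).

```python
from typing import Tuple

Shape = Tuple[int, ...]


def expected_forward_shapes(
    input_shape: Shape, d_output: int, transposed_input: bool, pooling: bool
) -> Tuple[Shape, Shape]:
    """Return (canonical_input_shape, output_shape) for a decoder-enabled model.

    Hypothetical helper: the canonical input layout is
    (batch_size, sequence_length, d_input).
    """
    if len(input_shape) != 3:
        raise ValueError(f"expected a 3-D input, got shape {input_shape}")
    if transposed_input:
        batch_size, d_input, seq_len = input_shape  # (batch_size, d_input, sequence_length)
    else:
        batch_size, seq_len, d_input = input_shape  # (batch_size, sequence_length, d_input)
    canonical = (batch_size, seq_len, d_input)
    if pooling:
        output = (batch_size, d_output)  # sequence dimension pooled away
    else:
        output = (batch_size, seq_len, d_output)  # assumed channel-last layout
    return canonical, output


# A batch of 8 single-channel signals of length 1000, stored channel-first.
print(expected_forward_shapes((8, 1, 1000), d_output=3, transposed_input=True, pooling=True))
# ((8, 1000, 1), (8, 3))
```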