qumphy.models.s42 module

File: qumphy/models/s42.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: Standalone version of the Structured (Sequence) State Space (S4) model.

qumphy.models.s42.Activation(activation=None, dim=-1)[source]

Return an activation module selected by name.

Parameters:
  • activation (str) – Name of the activation function.

  • dim (int) – Dimension argument for activations that operate along a specific axis.

Returns:

Activation module.

Return type:

nn.Module
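A minimal sketch of the kind of name-to-module lookup that Activation(activation=None, dim=-1) performs, assuming it returns a torch.nn activation module. The function name activation_from_name and the set of accepted names are hypothetical; check the source for the names this module actually supports.

```python
import torch
from torch import nn


def activation_from_name(activation=None, dim=-1):
    """Hypothetical name-to-module lookup; the accepted names are assumptions."""
    if activation in (None, "id", "identity", "linear"):
        return nn.Identity()
    if activation == "glu":
        # GLU halves its input along `dim`, so this is where `dim` matters
        return nn.GLU(dim=dim)
    modules = {
        "tanh": nn.Tanh,
        "relu": nn.ReLU,
        "gelu": nn.GELU,
        "swish": nn.SiLU,
        "silu": nn.SiLU,
        "sigmoid": nn.Sigmoid,
    }
    if activation not in modules:
        raise NotImplementedError(f"activation '{activation}' is not implemented")
    return modules[activation]()
```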

class qumphy.models.s42.HippoSSKernel(*args: Any, **kwargs: Any)[source]

Bases: Module

Wrapper around SSKernel that generates A, B, C, dt according to HiPPO arguments.

The SSKernel is expected to support the interface forward(), default_state(), setup_step(), and step().

default_state(*args, **kwargs)[source]
forward(L=None, rate=1.0)[source]
step(u, state, **kwargs)[source]
qumphy.models.s42.LinearActivation(d_input, d_output, bias=True, zero_bias_init=False, transposed=False, initializer=None, activation=None, activate=False, weight_norm=False, **kwargs)[source]

Create a linear layer with optional initialization and activation.

Parameters:
  • d_input (int) – Number of input features.

  • d_output (int) – Number of output features.

  • bias (bool) – If True, include a bias parameter.

  • zero_bias_init (bool) – If True, initialize the bias with zeros.

  • transposed (bool) – If True, use TransposedLinear instead of nn.Linear.

  • initializer (str) – Name of the weight initializer.

  • activation (str) – Name of the activation function.

  • activate (bool) – If True, append the activation function to the linear layer.

  • weight_norm (bool) – If True, apply weight normalization to the linear layer.

  • **kwargs – Additional keyword arguments passed to the linear layer.

Returns:

Linear module, optionally followed by an activation function.

Return type:

nn.Module
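The sketch below shows how the main LinearActivation options combine, assuming the layer is a plain nn.Linear optionally wrapped in nn.Sequential together with an activation. linear_activation_sketch is a hypothetical, reduced version: it omits transposed, initializer, and weight_norm, and supports only three activation names.

```python
import torch
from torch import nn


def linear_activation_sketch(d_input, d_output, bias=True, zero_bias_init=False,
                             activation=None, activate=False):
    """Hypothetical reduction of LinearActivation (no transposed/initializer/weight_norm)."""
    linear = nn.Linear(d_input, d_output, bias=bias)
    if bias and zero_bias_init:
        nn.init.zeros_(linear.bias)
    if activate and activation is not None:
        # Only a few names are handled here; the real lookup may support more
        act = {"relu": nn.ReLU(), "tanh": nn.Tanh(), "gelu": nn.GELU()}[activation]
        return nn.Sequential(linear, act)
    return linear
```

In this sketch the activation is appended only when activate=True; passing activation alone leaves the layer linear.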

class qumphy.models.s42.S4(*args: Any, **kwargs: Any)[source]

Bases: Module

Structured State Space Sequence layer.

property d_output
property d_state
default_state(*batch_shape, device=None)[source]
forward(u, rate=1.0, **kwargs)[source]

Run a forward pass through the S4 layer.

Parameters:
  • u (torch.Tensor) – Input tensor. If transposed is True, the shape is (batch_size, d_model, sequence_length). Otherwise, the shape is (batch_size, sequence_length, d_model).

  • rate (float) – Sampling rate factor.

  • **kwargs – Additional keyword arguments kept for compatibility.

Returns:

Tuple containing the output tensor and None for compatibility with recurrent interfaces.

Return type:

tuple
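An S4 forward pass is usually dominated by a long causal convolution of the input with the SSM kernel, computed with FFTs. The sketch below assumes a kernel k of shape (H, L), a skip weight D of shape (H,), and the transposed layout (batch_size, d_model, sequence_length). fft_causal_conv and direct_causal_conv are hypothetical helpers; they leave out any gating, dropout, or output projection the layer may apply.

```python
import torch


def fft_causal_conv(u, k, D):
    """u: (B, H, L) input, k: (H, L) kernel, D: (H,) skip weights -> (B, H, L)."""
    L = u.shape[-1]
    # Zero-pad to 2L so the circular FFT convolution equals a linear (causal) one
    k_f = torch.fft.rfft(k, n=2 * L)
    u_f = torch.fft.rfft(u, n=2 * L)
    y = torch.fft.irfft(u_f * k_f, n=2 * L)[..., :L]
    # Skip term D * u, broadcast over batch and time
    return y + u * D.unsqueeze(-1)


def direct_causal_conv(u, k, D):
    """O(L^2) reference: y[t] = sum_{j<=t} k[j] * u[t - j] + D * u[t]."""
    L = u.shape[-1]
    y = torch.zeros_like(u)
    for t in range(L):
        for j in range(t + 1):
            y[..., t] += k[:, j] * u[..., t - j]
    return y + u * D.unsqueeze(-1)
```

The padding to 2L is the key design choice: without it the FFT product would wrap the tail of the sequence around to its start.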

property state_to_tensor
step(u, state)[source]

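Advance the layer by one time step as a recurrent model. Intended for use during validation.

Parameters:
  • u (torch.Tensor) – Input of shape (B, H), where B is the batch size and H is d_model.

  • state (torch.Tensor) – Recurrent state of shape (B, H, N), where N is the state size.

Returns:

Output of shape (B, H) and updated state of shape (B, H, N).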
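To illustrate why step() can replace the convolutional forward pass at inference time, the sketch below runs the recurrence x_k = dA x_{k-1} + dB u_k, y_k = dC x_k for a single feature channel and compares it with a causal convolution by the kernel K = (dC dB, dC dA dB, dC dA^2 dB, ...). ssm_step and ssm_kernel are hypothetical helpers; the dense real dA and the missing D skip term are simplifications, not the layer's actual parameterization.

```python
import torch


def ssm_step(dA, dB, dC, u, x):
    """One recurrent step. dA: (N, N), dB: (N,), dC: (N,), u: (B,), x: (B, N)."""
    x = x @ dA.T + u.unsqueeze(-1) * dB  # x_k = dA x_{k-1} + dB u_k
    y = x @ dC                           # y_k = dC x_k
    return y, x


def ssm_kernel(dA, dB, dC, L):
    """Convolution kernel K[k] = dC dA^k dB for k = 0, ..., L-1."""
    out, v = [], dB
    for _ in range(L):
        out.append(dC @ v)
        v = dA @ v
    return torch.stack(out)
```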

class qumphy.models.s42.SSKernelNPLR(*args: Any, **kwargs: Any)[source]

Bases: Module

Stores a representation of, and computes, the SSKernel function K_L(A^dt, B^dt, C) corresponding to a discretized state space, where A is Normal Plus Low-Rank (NPLR).

The class name stands for ‘State-Space SSKernel for Normal Plus Low-Rank’. Its parameters are as follows:
  • A – (…, N, N) state matrix.

  • B – (…, N) input matrix.

  • C – (…, N) output matrix.

  • dt – (…) timescales / discretization step size.

  • p, q – (…, P, N) low-rank correction to A, such that Ap = A + pq^T is a normal matrix.

The forward pass of this Module returns a tensor of shape (…, L) representing the kernel SSKernel_L(A^dt, B^dt, C), computed via the FFT.
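The FFT step can be illustrated with dense matrices. The truncated generating function of the kernel is K_hat(z) = sum_{k<L} dC dA^k dB z^k = dC (I - dA^L z^L)(I - dA z)^{-1} dB, and at the L-th roots of unity z^L = 1, so K_hat(z) = dC (I - dA^L)(I - dA z)^{-1} dB; an inverse FFT then recovers K. This is a sketch only: kernel_via_fft is hypothetical, and it uses a dense batched solve where an NPLR implementation would exploit the normal plus low-rank structure (for example through a Cauchy kernel).

```python
import math
import torch


def kernel_via_fft(dA, dB, dC, L):
    """Length-L kernel K[k] = dC dA^k dB recovered from its generating function."""
    N = dA.shape[-1]
    A, B, C = (t.to(torch.complex128) for t in (dA, dB, dC))
    I = torch.eye(N, dtype=torch.complex128)
    C_tilde = C @ (I - torch.linalg.matrix_power(A, L))  # dC (I - dA^L)
    idx = torch.arange(L, dtype=torch.float64)
    z = torch.exp(-2j * math.pi * idx / L)  # nodes z_j = exp(-2*pi*i*j/L)
    M = I - z[:, None, None] * A  # (L, N, N): I - z_j dA
    x = torch.linalg.solve(M, B.expand(L, N).unsqueeze(-1)).squeeze(-1)  # (L, N)
    K_hat = x @ C_tilde  # (L,)
    # With these nodes K_hat is the DFT of K, so the inverse FFT recovers K
    return torch.fft.ifft(K_hat).real
```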

default_state(*batch_shape)[source]

Create an initial recurrent state.

Parameters:

*batch_shape (int) – Batch dimensions of the state.

Returns:

Zero-initialized recurrent state.

Return type:

torch.Tensor

double_length()

Double the internal kernel length.

Returns:

The function updates the internal length and cached FFT nodes.

Return type:

None

forward(state=None, rate=1.0, L=None)[source]

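Parameters:
  • state (torch.Tensor) – Optional extra tensor of shape (…, s, N) that augments B.

  • rate (float) – Sampling rate factor.

  • L (int) – Optional kernel length.

Returns:

Tensor of shape (…, c+s, L).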

register(name, tensor, trainable=False, lr=None, wd=None)[source]

Register a tensor as a parameter or buffer.

Parameters:
  • name (str) – Name used to register the tensor.

  • tensor (torch.Tensor) – Tensor to register.

  • trainable (bool) – If True, register the tensor as a trainable parameter. Otherwise, register it as a buffer.

  • lr (float) – Optional learning rate metadata.

  • wd (float) – Optional weight decay metadata.

Returns:

The function registers the tensor in the module.

Return type:

None
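A minimal sketch of the register pattern described above: trainable tensors become parameters, all others become buffers, and optional learning-rate and weight-decay hints are attached to the registered tensor. RegisterDemo is hypothetical, and storing the hints in an attribute named _optim is an assumption borrowed from the original S4 code, not a confirmed detail of this module.

```python
import torch
from torch import nn


class RegisterDemo(nn.Module):
    def register(self, name, tensor, trainable=False, lr=None, wd=None):
        if trainable:
            self.register_parameter(name, nn.Parameter(tensor))
        else:
            self.register_buffer(name, tensor)
        # Assumed convention: per-tensor optimizer hints stored as `_optim`
        optim = {}
        if trainable and lr is not None:
            optim["lr"] = lr
        if trainable and wd is not None:
            optim["weight_decay"] = wd
        if optim:
            setattr(getattr(self, name), "_optim", optim)
```

A training script could then collect the parameters that carry _optim into separate optimizer parameter groups.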

setup_step(mode='dense')[source]

Set up the discretized parameters dA, dB, and dC used for stepping.

step(u, state)[source]

Advance the kernel by one recurrent step. Before calling this method, call self.setup_step() and create the state with self.default_state().

class qumphy.models.s42.TransposedLinear(*args: Any, **kwargs: Any)[source]

Bases: Module

Linear module applied to the second-to-last dimension.

forward(x)[source]

Run a forward pass through the transposed linear layer.

Parameters:

x (torch.Tensor) – Input tensor.

Returns:

Output tensor after applying the linear transformation.

Return type:

torch.Tensor
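The sketch below shows what "linear on the second-to-last dimension" means: an input of shape (…, d_input, L) is mapped to (…, d_output, L), which equals nn.Linear applied to the transposed input. TransposedLinearSketch is hypothetical; the zero bias initialization is a simplification and may differ from the real class.

```python
import math
import torch
from torch import nn


class TransposedLinearSketch(nn.Module):
    """Linear map over dim -2: (..., d_input, L) -> (..., d_output, L)."""

    def __init__(self, d_input, d_output, bias=True):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(d_output, d_input))
        # Weight init matching nn.Linear's default
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        self.bias = nn.Parameter(torch.zeros(d_output, 1)) if bias else None

    def forward(self, x):
        # Contract the feature axis u, keep the trailing length axis l
        y = torch.einsum("...ul,vu->...vl", x, self.weight)
        return y + self.bias if self.bias is not None else y
```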

qumphy.models.s42.bilinear(dt, A, B=None)[source]

Apply bilinear discretization to a continuous state-space system.

Parameters:
  • dt (torch.Tensor) – Time step or timescale tensor.

  • A (torch.Tensor) – Continuous transition matrix.

  • B (torch.Tensor) – Optional continuous input matrix.

Returns:

Discretized transition matrix and discretized input matrix.

Return type:

tuple
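The bilinear (Tustin) discretization maps a continuous system x' = A x + B u to dA = (I - dt/2 A)^{-1}(I + dt/2 A) and dB = (I - dt/2 A)^{-1} dt B. The sketch below implements these formulas for a scalar dt, using torch.linalg.solve instead of an explicit inverse. bilinear_sketch is a hypothetical name, and returning dA alone when B is None is an assumption about the real function.

```python
import torch


def bilinear_sketch(dt, A, B=None):
    """Bilinear discretization for a scalar dt; A: (N, N), B: (N,)."""
    N = A.shape[-1]
    I = torch.eye(N, dtype=A.dtype)
    back = I - (dt / 2) * A   # I - dt/2 A
    fwd = I + (dt / 2) * A    # I + dt/2 A
    dA = torch.linalg.solve(back, fwd)  # (I - dt/2 A)^{-1} (I + dt/2 A)
    if B is None:
        return dA  # assumption: only dA when no input matrix is given
    dB = torch.linalg.solve(back, dt * B.unsqueeze(-1)).squeeze(-1)  # (I - dt/2 A)^{-1} dt B
    return dA, dB
```

For the scalar system A = -1, B = 1 with dt = 0.1, this gives dA = 0.95 / 1.05 and dB = 0.1 / 1.05.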

qumphy.models.s42.cauchy_conj(v, z, w)[source]

Compute the Cauchy multiplication using PyKeOps.

Parameters:
  • v (torch.Tensor) – Numerator tensor.

  • z (torch.Tensor) – Evaluation points.

  • w (torch.Tensor) – Complex poles.

Returns:

Result of the Cauchy multiplication.

Return type:

torch.Tensor
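As a reference for what the PyKeOps routine computes, the sketch below evaluates a Cauchy sum densely in plain PyTorch. It assumes the conjugate-pair convention of the original S4 code, where each pole w_n also contributes its conjugate: r(z) = sum_n v_n / (z - w_n) + conj(v_n) / (z - conj(w_n)). cauchy_conj_naive and the shapes v, w: (…, N), z: (L,) are assumptions.

```python
import torch


def cauchy_conj_naive(v, z, w):
    """Dense Cauchy sum over conjugate pole pairs.

    v, w: (..., N) complex; z: (L,) complex; returns (..., L).
    """
    v = v.unsqueeze(-1)  # (..., N, 1)
    w = w.unsqueeze(-1)  # (..., N, 1)
    terms = v / (z - w) + v.conj() / (z - w.conj())  # (..., N, L)
    return terms.sum(dim=-2)
```

This dense version materializes an (…, N, L) intermediate; a KeOps-based implementation is typically used to avoid that memory cost for long sequences.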

qumphy.models.s42.embed_c2r(A)[source]
qumphy.models.s42.get_initializer(name, activation=None)[source]

Get a weight initialization function.

Parameters:
  • name (str) – Name of the initializer. Supported values are “uniform”, “normal”, “xavier”, “zero”, and “one”.

  • activation (str) – Activation function used to determine the initialization gain.

Returns:

Weight initialization function.

Return type:

callable
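A minimal sketch of how the supported initializer names could map onto torch.nn.init functions, with the activation selecting the gain. get_initializer_sketch is hypothetical: the choice of Kaiming variants for "uniform" and "normal" and treating GELU/SiLU as ReLU-like for the gain are assumptions, not confirmed behavior of this module.

```python
from functools import partial

import torch
from torch import nn


def get_initializer_sketch(name, activation=None):
    """Hypothetical mapping from initializer names to torch.nn.init functions."""
    if activation in (None, "id", "identity", "linear"):
        nonlinearity = "linear"
    elif activation in ("relu", "tanh", "sigmoid"):
        nonlinearity = activation
    elif activation in ("gelu", "swish", "silu"):
        nonlinearity = "relu"  # assumption: ReLU-like gain for smooth ReLU variants
    else:
        raise NotImplementedError(f"unsupported activation '{activation}'")

    if name == "uniform":
        return partial(nn.init.kaiming_uniform_, nonlinearity=nonlinearity)
    if name == "normal":
        return partial(nn.init.kaiming_normal_, nonlinearity=nonlinearity)
    if name == "xavier":
        return nn.init.xavier_normal_
    if name == "zero":
        return partial(nn.init.constant_, val=0.0)
    if name == "one":
        return partial(nn.init.constant_, val=1.0)
    raise NotImplementedError(f"unsupported initializer '{name}'")
```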

qumphy.models.s42.krylov(L, A, b, c=None, return_power=False)[source]

Compute a Krylov sequence.

Parameters:
  • L (int) – Length of the Krylov sequence.

  • A (torch.Tensor) – Square transition matrix.

  • b (torch.Tensor) – Initial vector.

  • c (torch.Tensor) – Optional projection vector.

  • return_power (bool) – If True, also return A raised to the power L - 1.

Returns:

Krylov sequence, optionally together with A raised to the power L - 1.

Return type:

torch.Tensor or tuple
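The Krylov sequence of A and b is [b, A b, A^2 b, …, A^(L-1) b]. The sketch below computes it with L - 1 matrix-vector products, for an unbatched A of shape (N, N). krylov_naive is hypothetical, and stacking the vectors along a new last axis (shape (N, L)) is an assumed layout that the real function may not share.

```python
import torch


def krylov_naive(L, A, b, c=None, return_power=False):
    """Stack [b, A b, ..., A^{L-1} b] along a new last axis -> (N, L)."""
    cols, x = [], b
    for _ in range(L):
        cols.append(x)
        x = A @ x
    K = torch.stack(cols, dim=-1)  # (N, L)
    if c is not None:
        K = c @ K  # project: entry k is c^T A^k b, shape (L,)
    if return_power:
        return K, torch.linalg.matrix_power(A, L - 1)
    return K
```

With the projection vector c, the result is exactly the SSM convolution kernel (c^T b, c^T A b, …) when A, b, c are the discretized state, input, and output matrices.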

qumphy.models.s42.nplr(measure, N, rank=1, dtype=torch.float)[source]

Convert a HiPPO matrix into normal plus low-rank form.

Parameters:
  • measure (str) – Type of HiPPO measure.

  • N (int) – State dimension.

  • rank (int) – Rank of the low-rank correction.

  • dtype (torch.dtype) – Floating point data type.

Returns:

Tuple containing eigenvalues, low-rank correction, input vector, and eigenvector matrix.

Return type:

tuple
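For the LegS measure, the NPLR idea can be checked directly: with the HiPPO-LegS matrix A and the rank-1 vector P_n = sqrt(n + 1/2), the sum A + P P^T equals -1/2 I plus a skew-symmetric matrix, so it is normal and all its eigenvalues have real part -1/2. legs_nplr_sketch is hypothetical; the actual nplr may scale P differently, keep only one eigenvalue of each conjugate pair, or order the outputs differently.

```python
import torch


def legs_nplr_sketch(N):
    """Hypothetical LegS example: A, rank-1 P, normal S = A + P P^T, and eig(S)."""
    n = torch.arange(N, dtype=torch.float64)
    r = torch.sqrt(2 * n + 1)
    # LegS: A[i, j] = -sqrt(2i+1) sqrt(2j+1) for i > j, -(i+1) for i == j, 0 for i < j
    A = -torch.tril(torch.outer(r, r), diagonal=-1) - torch.diag(n + 1)
    P = torch.sqrt(n + 0.5)    # rank-1 correction
    S = A + torch.outer(P, P)  # -1/2 I + skew-symmetric, hence normal
    W, V = torch.linalg.eig(S)  # complex eigenvalues and eigenvectors
    return A, P, S, W, V
```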

qumphy.models.s42.power(L, A, v=None)[source]

Compute a matrix power and optional scan reduction.

Parameters:
  • L (int) – Power to which the matrix is raised.

  • A (torch.Tensor) – Square matrix of shape (…, N, N).

  • v (torch.Tensor) – Optional tensor used to compute the scan sum over powers of A.

Returns:

If v is None, returns A raised to the power L. Otherwise, returns the matrix power and the scan reduction.

Return type:

torch.Tensor or tuple
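The sketch below computes A^L by repeated squaring (O(log L) matrix products) and, when v is given, the scan sum sum_i A^i v[…, i] over the last axis of v. power_sketch is hypothetical, and that scan-sum convention is an assumption; the sum is evaluated naively here rather than with a divide-and-conquer reduction.

```python
import torch


def power_sketch(L, A, v=None):
    """A^L by repeated squaring; optionally the scan sum sum_i A^i v[..., i]."""
    N = A.shape[-1]
    result = torch.eye(N, dtype=A.dtype)
    base, k = A, L
    while k > 0:
        if k % 2 == 1:
            result = base @ result  # powers of A commute, so order is irrelevant
        base = base @ base
        k //= 2
    if v is None:
        return result
    # Naive scan sum over the last axis of v: (N, T) -> (N,)
    acc = torch.zeros_like(v[..., 0])
    Ak = torch.eye(N, dtype=A.dtype)
    for i in range(v.shape[-1]):
        acc = acc + Ak @ v[..., i]
        Ak = A @ Ak
    return result, acc
```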

qumphy.models.s42.rank_correction(measure, N, rank=1, dtype=torch.float)[source]

Return a low-rank matrix L such that A + L is normal.

Parameters:
  • measure (str) – Type of HiPPO measure.

  • N (int) – State dimension.

  • rank (int) – Rank of the correction matrix.

  • dtype (torch.dtype) – Data type of the returned tensor.

Returns:

Low-rank correction matrix.

Return type:

torch.Tensor

qumphy.models.s42.transition(measure, N, **measure_args)[source]

Return the A and B transition matrices for different measures.

The measure argument selects the type of measure:
  • legt – Legendre (translated)

  • legs – Legendre (scaled)

  • glagt – generalized Laguerre (translated)

  • lagt, tlagt – previous versions of (tilted) Laguerre with slightly different normalization

Parameters:
  • measure (str) – Type of HiPPO measure. Supported values include “legt”, “legs”, “glagt”, “lagt”, “fourier”, “random”, and “diagonal”.

  • N (int) – State dimension.

  • **measure_args – Additional arguments for the selected measure.

Returns:

Tuple containing the transition matrix A and input matrix B.

Return type:

tuple
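A minimal sketch of two of the measures, following the formulas of the original HiPPO and S4 work: LegS with A[i, j] = -sqrt(2i+1) sqrt(2j+1) below the diagonal, -(i+1) on it, and B_i = sqrt(2i+1); and LagT with A = I/2 minus a lower-triangular matrix of ones and B = beta times ones. transition_sketch and its beta argument are hypothetical, and this module's normalization may differ (the docstring itself notes that lagt variants differ in normalization).

```python
import torch


def transition_sketch(measure, N, beta=1.0):
    """Hypothetical HiPPO (A, B) pair for the 'legs' and 'lagt' measures."""
    n = torch.arange(N, dtype=torch.float64)
    if measure == "legs":
        r = torch.sqrt(2 * n + 1)
        # -sqrt(2i+1) sqrt(2j+1) below the diagonal, -(i+1) on it, 0 above it
        A = -torch.tril(torch.outer(r, r), diagonal=-1) - torch.diag(n + 1)
        B = r.clone()
    elif measure == "lagt":
        # I/2 minus lower-triangular ones: -1/2 on the diagonal, -1 below it
        ones = torch.ones(N, N, dtype=torch.float64)
        A = torch.eye(N, dtype=torch.float64) / 2 - torch.tril(ones)
        B = beta * torch.ones(N, dtype=torch.float64)
    else:
        raise NotImplementedError(f"measure '{measure}' is not covered by this sketch")
    return A, B
```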