"""
File: qumphy/models/s42.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: Standalone version of Structured (Sequence) State Space (S4) model.
"""
# version from the SSM ECG repo
# https://github.com/HazyResearch/state-spaces/blob/main/src/models/sequence/ss/standalone/s4.py
from functools import partial
import math
import numpy as np
from scipy import special as ss
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
import opt_einsum as oe
from pykeops.torch import Genred
""" Cauchy kernel """
try: # Try CUDA extension
from extensions.cauchy.cauchy import cauchy_mult
has_cauchy_extension = True
except ImportError:
# log.warn(
# "CUDA extension for cauchy multiplication not found. Install by going to extensions/cauchy/ and running `python setup.py install`. This should speed up end-to-end training by 10-50%"
# )
has_cauchy_extension = False
def _broadcast_dims(*tensors):
"""Broadcast tensors by adding leading singleton dimensions.
Parameters
----------
*tensors : torch.Tensor
Tensors with possibly different numbers of dimensions.
Returns
-------
list
List of reshaped tensors with the same number of dimensions.
"""
max_dim = max([len(tensor.shape) for tensor in tensors])
tensors = [
tensor.view((1,) * (max_dim - len(tensor.shape)) + tensor.shape)
for tensor in tensors
]
return tensors
# Short aliases used throughout this module to reinterpret a tensor between a
# complex dtype and a real tensor with a trailing (real, imag) axis of size 2.
_c2r = torch.view_as_real
_r2c = torch.view_as_complex
def _conj(x):
"""Broadcast tensors by adding leading singleton dimensions.
Parameters
----------
*tensors : torch.Tensor
Tensors with possibly different numbers of dimensions.
Returns
-------
list
List of reshaped tensors with the same number of dimensions.
"""
return torch.cat([x, x.conj()], dim=-1)
if tuple(map(int, torch.__version__.split(".")[:2])) >= (1, 10):
def _resolve_conj(x):
"""Concatenate a complex tensor with its conjugate.
Parameters
----------
x : torch.Tensor
Complex-valued input tensor.
Returns
-------
torch.Tensor
Tensor containing the input and its complex conjugate along the last
dimension.
"""
return x.conj().resolve_conj()
else:
def _resolve_conj(x):
"""Return the resolved complex conjugate of a tensor.
Parameters
----------
x : torch.Tensor
Complex-valued input tensor.
Returns
-------
torch.Tensor
Resolved conjugate tensor.
"""
return x.conj()
def cauchy_conj(v, z, w):
    """Compute the Cauchy multiplication using PyKeOps.

    Parameters
    ----------
    v : torch.Tensor
        Numerator tensor (complex).
    z : torch.Tensor
        Evaluation points (complex).
    w : torch.Tensor
        Complex poles.

    Returns
    -------
    torch.Tensor
        Complex tensor: twice the KeOps ``Sum`` reduction (``axis=1``) of
        ``ComplexDivide(z * ComplexReal(v) - Real2Complex(Sum(v * w)),
        ComplexMult(z - w, z - Conj(w)))``.

    Notes
    -----
    The reduction is always requested with ``backend="GPU"``.
    """
    numerator = "z * ComplexReal(v) - Real2Complex(Sum(v * w))"
    denominator = "ComplexMult(z-w, z-Conj(w))"
    formula = f"ComplexDivide({numerator}, {denominator})"
    variables = ["v = Vj(2)", "z = Vi(2)", "w = Vj(2)"]
    reduction = Genred(formula, variables, reduction_op="Sum", axis=1)

    # Give all operands the same rank, then hand KeOps the real (..., 2) views.
    v, z, w = _broadcast_dims(v, z, w)
    v_real, z_real, w_real = _c2r(v), _c2r(z), _c2r(w)

    summed = 2 * reduction(v_real, z_real, w_real, backend="GPU")
    return _r2c(summed)
""" simple nn.Module components """
def Activation(activation=None, dim=-1):
    """Build the activation module that corresponds to a name.

    Parameters
    ----------
    activation : str or None
        One of ``None``/"id"/"identity"/"linear" (identity), "tanh", "relu",
        "gelu", "swish"/"silu", "glu" or "sigmoid".
    dim : int
        Dimension passed to ``nn.GLU``; ignored by every other activation.

    Returns
    -------
    nn.Module
        Freshly constructed activation module.

    Raises
    ------
    NotImplementedError
        If ``activation`` is not one of the supported names.
    """
    # (accepted names, zero-argument factory), checked in order.
    table = (
        ((None, "id", "identity", "linear"), nn.Identity),
        (("tanh",), nn.Tanh),
        (("relu",), nn.ReLU),
        (("gelu",), nn.GELU),
        (("swish", "silu"), nn.SiLU),
        (("glu",), lambda: nn.GLU(dim=dim)),
        (("sigmoid",), nn.Sigmoid),
    )
    for names, factory in table:
        if activation in names:
            return factory()
    raise NotImplementedError(
        "hidden activation '{}' is not implemented".format(activation)
    )
def get_initializer(name, activation=None):
    """Look up an in-place weight initializer by name.

    Parameters
    ----------
    name : str
        Initializer name: "uniform" (Kaiming uniform), "normal" (Kaiming
        normal), "xavier" (Xavier normal), "zero" or "one".
    activation : str or None
        Activation used to pick the ``nonlinearity`` argument of the Kaiming
        initializers. It is validated even when ``name`` does not use it.

    Returns
    -------
    callable
        Function that initializes a tensor in place.

    Raises
    ------
    NotImplementedError
        If ``activation`` or ``name`` is not supported. The activation is
        checked first.
    """
    if activation in (None, "id", "identity", "linear", "modrelu"):
        gain_mode = "linear"
    elif activation in ("relu", "tanh", "sigmoid"):
        gain_mode = activation
    elif activation in ("gelu", "swish", "silu"):
        # Close to ReLU, so approximate with ReLU's gain.
        gain_mode = "relu"
    else:
        raise NotImplementedError(
            f"get_initializer: activation {activation} not supported"
        )

    if name == "uniform":
        return partial(torch.nn.init.kaiming_uniform_, nonlinearity=gain_mode)
    if name == "normal":
        return partial(torch.nn.init.kaiming_normal_, nonlinearity=gain_mode)
    if name == "xavier":
        return torch.nn.init.xavier_normal_
    if name == "zero":
        return partial(torch.nn.init.constant_, val=0)
    if name == "one":
        return partial(torch.nn.init.constant_, val=1)
    raise NotImplementedError(
        f"get_initializer: initializer type {name} not supported"
    )
class TransposedLinear(nn.Module):
    """Linear module on the second-to-last dimension"""

    def __init__(self, d_input, d_output, bias=True):
        """Create the weight (and optional bias) of the layer.

        Parameters
        ----------
        d_input : int
            Number of input features.
        d_output : int
            Number of output features.
        bias : bool
            If True, learn a bias of shape ``(d_output, 1)``; otherwise the
            bias is the constant ``0.0``.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.empty(d_output, d_input))
        # Same call nn.Linear uses for its default weight init.
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

        if not bias:
            self.bias = 0.0
            return

        self.bias = nn.Parameter(torch.empty(d_output, 1))
        limit = 1 / math.sqrt(d_input)
        nn.init.uniform_(self.bias, -limit, limit)

    def forward(self, x):
        """Apply the linear map along the second-to-last axis of ``x``.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape ``(..., d_input, l)``.

        Returns
        -------
        torch.Tensor
            Output of shape ``(..., d_output, l)``.
        """
        projected = oe.contract("... u l, v u -> ... v l", x, self.weight)
        return projected + self.bias
def LinearActivation(
    d_input,
    d_output,
    bias=True,
    zero_bias_init=False,
    transposed=False,
    initializer=None,
    activation=None,
    activate=False,  # Apply activation as part of this module
    weight_norm=False,
    **kwargs,
):
    """Build a linear layer with optional custom init, weight norm and activation.

    Parameters
    ----------
    d_input : int
        Number of input features.
    d_output : int
        Number of output features. Doubled internally when ``activation`` is
        "glu".
    bias : bool
        If True, the linear layer has a bias.
    zero_bias_init : bool
        If True (and ``bias`` is True), zero the bias after construction.
    transposed : bool
        If True, use TransposedLinear instead of nn.Linear.
    initializer : str or None
        Name understood by ``get_initializer``; ``None`` keeps the layer's
        default weight init.
    activation : str or None
        Name of the activation function.
    activate : bool
        If True and ``activation`` is not None, wrap the layer and the
        activation in an ``nn.Sequential``.
    weight_norm : bool
        If True, wrap the linear layer with ``nn.utils.weight_norm``.
    **kwargs
        Extra keyword arguments forwarded to the linear layer constructor.

    Returns
    -------
    nn.Module
        The linear layer, or ``nn.Sequential(linear, activation)``.
    """
    # Pick the layer class lazily so only the one in use is looked up.
    if transposed:
        layer_type = TransposedLinear
    else:
        layer_type = nn.Linear

    if activation == "glu":
        d_output *= 2
    layer = layer_type(d_input, d_output, bias=bias, **kwargs)

    if initializer is not None:
        init_fn = get_initializer(initializer, activation)
        init_fn(layer.weight)

    if bias and zero_bias_init:
        nn.init.zeros_(layer.bias)

    if weight_norm:
        layer = nn.utils.weight_norm(layer)

    if activate and activation is not None:
        # The feature axis is second-to-last for the transposed layer.
        feature_dim = -2 if transposed else -1
        act_module = Activation(activation, dim=feature_dim)
        layer = nn.Sequential(layer, act_module)

    return layer
""" Misc functional utilities """
def krylov(L, A, b, c=None, return_power=False):
    """Compute the Krylov sequence (b, Ab, A^2 b, ..., A^{L-1} b).

    The sequence is built by repeated doubling: each iteration multiplies the
    columns computed so far by the current power of ``A`` and appends them.

    Parameters
    ----------
    L : int
        Length of the Krylov sequence; must be at least 1.
    A : torch.Tensor
        Square transition matrix of shape (..., N, N).
    b : torch.Tensor
        Initial vector of shape (..., N). Its batch dimensions are broadcast
        against those of ``A``.
    c : torch.Tensor, optional
        Projection vector of shape (..., N). If given, the sequence is
        contracted with it over the state dimension.
    return_power : bool
        If True, also return A raised to the power L - 1.

    Returns
    -------
    torch.Tensor or tuple
        Tensor of shape (..., N, L), or (..., L) when ``c`` is given. With
        ``return_power`` the result is ``(sequence, A^{L-1})``.

    Raises
    ------
    ValueError
        If ``L`` is smaller than 1.
    """
    if L < 1:
        raise ValueError(f"krylov: L must be >= 1, got {L}")

    # Broadcast b against A's batch dimensions up front. torch.cat below does
    # not broadcast, so without this the doubling loop fails when A carries
    # batch dimensions that b lacks, and L == 1 returns an un-broadcast result.
    batch_shape = torch.broadcast_shapes(A.shape[:-2], b.shape[:-1])
    x = b.expand(*batch_shape, b.shape[-1]).unsqueeze(-1)  # (..., N, 1)

    A_ = A

    AL = None
    if return_power:
        AL = torch.eye(A.shape[-1], dtype=A.dtype, device=A.device)
        _L = L - 1

    done = L == 1
    # loop invariant: _L represents how many indices left to compute
    while not done:
        if return_power:
            # Binary exponentiation of A^(L-1) using the same squarings of A.
            if _L % 2 == 1:
                AL = A_ @ AL
            _L //= 2

        # Save memory on last iteration
        length = x.shape[-1]
        if L - length <= length:
            done = True
            _x = x[..., : L - length]
        else:
            _x = x

        _x = A_ @ _x
        # there might be a more efficient way of ordering axes
        x = torch.cat([x, _x], dim=-1)
        if not done:
            A_ = A_ @ A_

    assert x.shape[-1] == L

    if c is not None:
        x = torch.einsum("...nl, ...n -> ...l", x, c)
    # Materialize the result; x may be an expanded or strided view here.
    x = x.contiguous()
    if return_power:
        return x, AL
    else:
        return x
def power(L, A, v=None):
    """Compute A^L and, optionally, the scan sum_i A^i v[..., i].

    Parameters
    ----------
    L : int
        Power to which the matrix is raised.
    A : torch.Tensor
        Square matrix of shape (..., N, N).
    v : torch.Tensor, optional
        Tensor of shape (..., N, L) whose columns are combined as
        ``v[..., 0] + A v[..., 1] + ... + A^{L-1} v[..., L-1]``. It is not
        modified.

    Returns
    -------
    torch.Tensor or tuple
        If ``v`` is None, returns A^L. Otherwise returns ``(A^L, scan)`` with
        ``scan`` of shape (..., N).
    """
    Id = torch.eye(A.shape[-1]).to(A)

    # Binary exponentiation, caching A^(2^i) for the reduction below.
    powers = [A]
    largest_pow = 1
    while True:
        if L % 2 == 1:
            Id = powers[-1] @ Id
        L //= 2
        if L == 0:
            break
        largest_pow *= 2
        powers.append(powers[-1] @ powers[-1])

    if v is None:
        return Id

    # Invariants:
    # powers[-1] := A^l
    # l := largest po2 at most L

    # Note that an alternative divide and conquer to compute the reduction is
    # possible and can be embedded into the above loop without caching
    # intermediate powers of A.
    # We do this reverse divide-and-conquer for efficiency reasons:
    # 1) it involves fewer padding steps for non-po2 L
    # 2) it involves more contiguous arrays

    # Take care of edge case for non-po2 arrays: fold the columns beyond the
    # largest power of two back onto the first k columns. This is done out of
    # place so the caller's tensor is left untouched.
    top_power = powers.pop()
    k = v.size(-1) - largest_pow
    head = v[..., :largest_pow]
    if k > 0:
        tail = top_power @ v[..., largest_pow:]
        v = torch.cat([head[..., :k] + tail, head[..., k:]], dim=-1)
    else:
        # Nothing to fold (power-of-two length).
        v = head

    # Handle reduction for power of 2: split the columns into two halves and
    # combine them as first + A^(half) @ second.
    while v.size(-1) > 1:
        v = v.reshape(*v.shape[:-1], 2, v.size(-1) // 2)
        v = v[..., 0, :] + powers.pop() @ v[..., 1, :]
    return Id, v.squeeze(-1)
""" HiPPO utilities """
def embed_c2r(A):
    """Expand every entry ``a`` of a 2-D array into the 2x2 block ``a * I``.

    Parameters
    ----------
    A : numpy.ndarray
        Array of shape (m, n).

    Returns
    -------
    numpy.ndarray
        Array of shape (2m, 2n) holding ``A[i, j]`` at positions
        ``(2i, 2j)`` and ``(2i + 1, 2j + 1)`` and zeros elsewhere.
    """
    # (m, n) -> (m, 1, n, 1)
    blocks = A[..., :, None, :, None]
    # Place each entry in the top-left and bottom-right corner of its 2x2 block.
    top_left = np.pad(blocks, ((0, 0), (0, 1), (0, 0), (0, 1)))
    bottom_right = np.pad(blocks, ((0, 0), (1, 0), (0, 0), (1, 0)))
    combined = top_left + bottom_right
    m, x, n, y = combined.shape
    return combined.reshape(m * x, n * y)
def transition(measure, N, **measure_args):
    """A, B transition matrices for different measures

    measure: the type of measure
      legt - Legendre (translated)
      legs - Legendre (scaled)
      glagt - generalized Laguerre (translated)
      lagt - previous version of (tilted) Laguerre with slightly different
        normalization

    Parameters
    ----------
    measure : str
        One of "lagt", "glagt", "legt", "legs", "fourier", "random" or
        "diagonal".
    N : int
        State dimension.
    **measure_args
        Optional parameters: ``beta`` for "lagt" (default 1.0); ``alpha``
        (default 0.0) and ``beta`` (default 0.01) for "glagt".

    Returns
    -------
    tuple
        ``(A, B)`` as NumPy arrays. "random" and "diagonal" draw from NumPy's
        global random state.

    Raises
    ------
    NotImplementedError
        If ``measure`` is not one of the supported names.
    """
    if measure == "lagt":
        # Laguerre (translated)
        scale = measure_args.get("beta", 1.0)
        A = np.eye(N) / 2 - np.tril(np.ones((N, N)))
        B = scale * np.ones((N, 1))
        return A, B

    if measure == "glagt":
        # Generalized Laguerre
        # alpha 0, beta small is most stable (limits to the 'lagt' measure)
        # alpha 0, beta 1 has transition matrix A = [lower triangular 1]
        alpha = measure_args.get("alpha", 0.0)
        beta = measure_args.get("beta", 0.01)
        A = -np.eye(N) * (1 + beta) / 2 - np.tril(np.ones((N, N)), -1)
        B = ss.binom(alpha + np.arange(N), np.arange(N))[:, None]

        # Diagonal rescaling applied to A as a similarity transform.
        scaling = np.exp(
            0.5 * (ss.gammaln(np.arange(N) + alpha + 1) - ss.gammaln(np.arange(N) + 1))
        )
        A = (1.0 / scaling[:, None]) * A * scaling[None, :]
        B = (
            (1.0 / scaling[:, None])
            * B
            * np.exp(-0.5 * ss.gammaln(1 - alpha))
            * beta ** ((1 - alpha) / 2)
        )
        return A, B

    if measure == "legt":
        # Legendre (translated)
        orders = np.arange(N, dtype=np.float64)
        norms = (2 * orders + 1) ** 0.5
        col_idx, row_idx = np.meshgrid(orders, orders)
        signs = np.where(row_idx < col_idx, (-1.0) ** (row_idx - col_idx), 1)
        A = -(norms[:, None] * signs * norms[None, :])
        B = norms[:, None]
        return A, B

    if measure == "legs":
        # Legendre (scaled)
        orders = np.arange(N, dtype=np.float64)
        col_idx, row_idx = np.meshgrid(orders, orders)
        weights = 2 * orders + 1
        M = -(np.where(row_idx >= col_idx, weights, 0) - np.diag(orders))
        T = np.sqrt(np.diag(2 * orders + 1))
        A = T @ M @ np.linalg.inv(T)
        # Copy, otherwise "UserWarning: given NumPY array is not writeable..."
        # after torch.as_tensor(B)
        B = np.diag(T)[:, None].copy()
        return A, B

    if measure == "fourier":
        half = N // 2
        freqs = np.arange(half)
        d = np.stack([freqs, np.zeros(half)], axis=-1).reshape(-1)[:-1]
        A = 2 * np.pi * (np.diag(d, 1) - np.diag(d, -1))
        A = A - embed_c2r(np.ones((half, half)))
        B = embed_c2r(np.ones((half, 1)))[..., :1]
        return A, B

    if measure == "random":
        # A is drawn before B to keep the random stream order.
        A = np.random.randn(N, N) / N
        B = np.random.randn(N, 1)
        return A, B

    if measure == "diagonal":
        A = -np.diag(np.exp(np.random.randn(N)))
        B = np.random.randn(N, 1)
        return A, B

    raise NotImplementedError
def rank_correction(measure, N, rank=1, dtype=torch.float):
    """Return low-rank matrix L such that A + L is normal

    Parameters
    ----------
    measure : str
        Type of HiPPO measure: "legs", "legt", "lagt" or "fourier".
    N : int
        State dimension.
    rank : int
        Requested number of rows. "legs" and "lagt" need at least 1, "legt"
        at least 2.
    dtype : torch.dtype
        Data type of the returned tensor.

    Returns
    -------
    torch.Tensor
        Tensor with N columns. If ``rank`` exceeds the measure's own number
        of rows, zero rows are appended to reach ``rank``.

    Raises
    ------
    NotImplementedError
        If ``measure`` is not supported.
    """

    def split_by_parity(base):
        # Row 0 keeps the odd-indexed entries, row 1 the even-indexed ones.
        odd_only = base.clone()
        odd_only[0::2] = 0.0
        even_only = base.clone()
        even_only[1::2] = 0.0
        return torch.stack([odd_only, even_only], dim=0)  # (2 N)

    if measure == "legs":
        assert rank >= 1
        P = torch.sqrt(0.5 + torch.arange(N, dtype=dtype)).unsqueeze(0)  # (1 N)
    elif measure == "legt":
        assert rank >= 2
        P = split_by_parity(torch.sqrt(1 + 2 * torch.arange(N, dtype=dtype)))
    elif measure == "lagt":
        assert rank >= 1
        P = 0.5**0.5 * torch.ones(1, N, dtype=dtype)
    elif measure == "fourier":
        P = split_by_parity(torch.ones(N, dtype=dtype))
    else:
        raise NotImplementedError

    # Pad with zero rows up to the requested rank.
    missing = rank - P.size(0)
    if missing > 0:
        padding = torch.zeros(missing, N, dtype=dtype)
        P = torch.cat([P, padding], dim=0)  # (rank N)
    return P
def nplr(measure, N, rank=1, dtype=torch.float):
    """Convert a HiPPO matrix into normal plus low-rank form.

    Parameters
    ----------
    measure : str
        Type of HiPPO measure, or "random" for a random initialization that
        skips the HiPPO construction.
    N : int
        State dimension.
    rank : int
        Rank of the low-rank correction.
    dtype : torch.dtype
        Real floating point type, ``torch.float`` or ``torch.double``.

    Returns
    -------
    tuple
        ``(w, P, B, V)``: eigenvalues (N/2,), low-rank correction
        (rank, N/2), input vector (N/2,) and eigenvector matrix (N, N/2).
        Only every other eigenvalue/eigenvector returned by
        ``torch.linalg.eig`` is kept.

    Raises
    ------
    ValueError
        If ``dtype`` is not ``torch.float`` or ``torch.double``.
    """
    # The previous check `dtype == torch.float or torch.cfloat` was always
    # true because of operator precedence, so it never rejected anything.
    if dtype not in (torch.float, torch.double):
        raise ValueError(
            f"nplr: dtype must be torch.float or torch.double, got {dtype}"
        )

    if measure == "random":
        dtype = torch.cfloat if dtype == torch.float else torch.cdouble
        # w = torch.randn(N//2, dtype=dtype)
        # NOTE(review): w is drawn without a dtype argument, so unlike P, B
        # and V it does not follow the requested precision - confirm intent.
        w = -torch.exp(torch.randn(N // 2)) + 1j * torch.randn(N // 2)
        P = torch.randn(rank, N // 2, dtype=dtype)
        B = torch.randn(N // 2, dtype=dtype)
        V = torch.eye(N, dtype=dtype)[..., : N // 2]  # Only used in testing
        return w, P, B, V

    A, B = transition(measure, N)
    A = torch.as_tensor(A, dtype=dtype)  # (N, N)
    B = torch.as_tensor(B, dtype=dtype)[:, 0]  # (N,)

    P = rank_correction(measure, N, rank=rank, dtype=dtype)
    # A plus the sum of the outer products of the rows of P.
    AP = A + torch.sum(P.unsqueeze(-2) * P.unsqueeze(-1), dim=-3)
    w, V = torch.linalg.eig(AP)  # (..., N) (..., N, N)
    # V w V^{-1} = A

    # Only keep one of the conjugate pairs
    w = w[..., 0::2].contiguous()
    V = V[..., 0::2].contiguous()

    # Project B and P onto the kept eigenvectors.
    V_inv = V.conj().transpose(-1, -2)

    B = oe.contract("ij, j -> i", V_inv, B.to(V))  # V^* B
    P = oe.contract("ij, ...j -> ...i", V_inv, P.to(V))  # V^* P

    return w, P, B, V
def bilinear(dt, A, B=None):
    """Apply bilinear discretization to a continuous state-space system.

    Computes ``dA = (I - dt/2 A)^{-1} (I + dt/2 A)`` and, if ``B`` is given,
    ``dB = dt (I - dt/2 A)^{-1} B``.

    Parameters
    ----------
    dt : torch.Tensor
        Step size tensor of shape (...); a 0-d tensor is accepted as well as
        the usual 1-D per-feature tensor.
    A : torch.Tensor
        Continuous transition matrix of shape (..., N, N).
    B : torch.Tensor, optional
        Continuous input vector of shape (..., N).

    Returns
    -------
    tuple
        ``(dA, dB)``; ``dB`` is None when ``B`` is None.
    """
    N = A.shape[-1]
    Id = torch.eye(N).to(A)
    # Index dt with an ellipsis (as the dB line does) so any number of
    # leading dimensions works, not only a 1-D dt.
    half_step = dt[..., None, None] / 2 * A
    A_backwards = Id - half_step
    A_forwards = Id + half_step

    if B is None:
        dB = None
    else:
        dB = dt[..., None] * torch.linalg.solve(A_backwards, B.unsqueeze(-1)).squeeze(
            -1
        )  # (... N)

    dA = torch.linalg.solve(A_backwards, A_forwards)  # (... N N)
    return dA, dB
class SSKernelNPLR(nn.Module):
    """Stores a representation of and computes the SSKernel function K_L(A^dt, B^dt, C) corresponding to a discretized state space, where A is Normal + Low Rank (NPLR)

    The class name stands for 'State-Space SSKernel for Normal Plus Low-Rank'.
    The parameters of this function are as follows.

    A: (... N N) the state matrix
    B: (... N) input matrix
    C: (... N) output matrix
    dt: (...) timescales / discretization step size
    p, q: (... P N) low-rank correction to A, such that Ap=A+pq^T is a normal matrix

    The forward pass of this Module returns:
    (... L) that represents FFT SSKernel_L(A^dt, B^dt, C)
    """

    @torch.no_grad()
    def _setup_C(self, double_length=False):
        """Construct C~ from C

        double_length: current C is for length L, convert it to length 2L

        The stored ``self.C`` is overwritten in place. When ``double_length``
        is True, ``self.L`` is doubled and the FFT nodes are re-cached.
        """
        C = _r2c(self.C)
        self._setup_state()
        dA_L = power(self.L, self.dA)
        # Multiply C by I - dA_L
        C_ = _conj(C)
        prod = oe.contract("h m n, c h n -> c h m", dA_L.transpose(-1, -2), C_)
        if double_length:
            prod = -prod  # Multiply by I + dA_L instead
        C_ = C_ - prod
        C_ = C_[..., : self.N]  # Take conjugate pairs again
        self.C.copy_(_c2r(C_))

        if double_length:
            self.L *= 2
            self._omega(self.L, dtype=C.dtype, device=C.device, cache=True)

    def _omega(self, L, dtype, device, cache=True):
        """Calculate (and cache) FFT nodes and their "unprocessed" version with the bilinear transform

        This should be called every time the internal length self.L changes

        Parameters
        ----------
        L : int
            Kernel length.
        dtype : torch.dtype
            Complex data type of the nodes.
        device : torch.device
            Device on which the nodes are created.
        cache : bool
            If True, register the nodes as the buffers ``omega`` and ``z``
            (stored as real views).

        Returns
        -------
        tuple
            ``(omega, z)``, each of length ``L // 2 + 1``, where
            ``omega[k] = exp(-2j * pi * k / L)`` and
            ``z = 2 * (1 - omega) / (1 + omega)``.
        """
        omega = torch.tensor(
            np.exp(-2j * np.pi / (L)), dtype=dtype, device=device
        )  # primitive L-th root of unity, exp(-2j * pi / L)
        omega = omega ** torch.arange(0, L // 2 + 1, device=device)
        z = 2 * (1 - omega) / (1 + omega)
        if cache:
            self.register_buffer("omega", _c2r(omega))
            self.register_buffer("z", _c2r(z))
        return omega, z

    @staticmethod
    def _resolve_trainable(trainable):
        """Normalize the ``trainable`` argument of ``__init__``.

        Parameters
        ----------
        trainable : None, bool or dict
            ``None``/falsy: nothing is trainable. ``True``: everything is
            trainable. A dict gives per-parameter flags; keys read by
            ``__init__`` are "dt", "A", "B" and "P".

        Returns
        -------
        tuple
            ``(flags, default)``: the per-parameter dict and the default used
            for keys missing from it.
        """
        if trainable is None:
            return {}, False
        if isinstance(trainable, dict):
            # Keep per-parameter flags; missing keys default to not trainable.
            return trainable, False
        return {}, bool(trainable)

    def __init__(
        self,
        L,
        w,
        P,
        B,
        C,
        log_dt,
        hurwitz=False,
        trainable=None,
        lr=None,
        tie_state=False,
        length_correction=True,
        verbose=False,
    ):
        """
        L: Maximum length; this module computes an SSM kernel of length L
        w: (N)
        p: (r, N) low-rank correction to A
        q: (r, N)
        A represented by diag(w) - pq^*

        B: (N)
        dt: (H) timescale per feature
        C: (C, H, N) system is 1-D to c-D (channels)

        hurwitz: tie pq and ensure w has negative real part
        trainable: toggle which of the parameters is trainable; None/False
            (nothing), True (everything) or a dict with any of the keys
            "dt", "A", "B", "P"
        lr: add hook to set lr of hippo parameters specially (everything besides C)
        tie_state: tie all state parameters across the H hidden features
        length_correction: multiply C by (I - dA^L) - can be turned off when L is large for slight speedup at initialization (only relevant when N large as well)

        Note: tensor shape N here denotes half the true state size, because of conjugate symmetry
        """
        super().__init__()
        self.hurwitz = hurwitz
        self.tie_state = tie_state
        self.verbose = verbose

        # Rank of low-rank correction
        self.rank = P.shape[-2]
        assert w.size(-1) == P.size(-1) == B.size(-1) == C.size(-1)
        self.H = log_dt.size(-1)
        self.N = w.size(-1)

        # Broadcast everything to correct shapes
        C = C.expand(torch.broadcast_shapes(C.shape, (1, self.H, self.N)))  # (C, H, N)
        H = 1 if self.tie_state else self.H
        B = repeat(B, "n -> 1 h n", h=H)
        P = repeat(P, "r n -> r h n", h=H)
        w = repeat(w, "n -> h n", h=H)

        # Cache Fourier nodes every time we set up a desired length
        self.L = L
        if self.L is not None:
            self._omega(self.L, dtype=C.dtype, device=C.device, cache=True)

        # Register parameters
        # C is a regular parameter, not state
        self.C = nn.Parameter(_c2r(_resolve_conj(C)))

        # A dict must keep its per-parameter flags; only a bare True turns
        # everything on. Previously any non-empty dict was discarded and
        # treated like True.
        trainable, train = self._resolve_trainable(trainable)
        self.register("log_dt", log_dt, trainable.get("dt", train), lr, 0.0)
        self.register("B", _c2r(B), trainable.get("B", train), lr, 0.0)
        self.register("P", _c2r(P), trainable.get("P", train), lr, 0.0)
        if self.hurwitz:
            # Some of the HiPPO methods have real part 0
            log_w_real = torch.log(-w.real + 1e-3)
            w_imag = w.imag
            # NOTE(review): the default here is 0 rather than `train`, so
            # trainable=True leaves log_w_real a buffer while w_imag becomes
            # a parameter - confirm this is intended.
            self.register("log_w_real", log_w_real, trainable.get("A", 0), lr, 0.0)
            self.register("w_imag", w_imag, trainable.get("A", train), lr, 0.0)
            self.Q = None
        else:
            self.register("w", _c2r(w), trainable.get("A", train), lr, 0.0)
            Q = _resolve_conj(P.clone())
            self.register("Q", _c2r(Q), trainable.get("P", train), lr, 0.0)

        if length_correction:
            self._setup_C()

    def _w(self):
        """Return the complex diagonal ``w`` from its stored parameterization."""
        # Get the internal w (diagonal) parameter
        if self.hurwitz:
            # Real part is stored as log(-Re w), so it is always negative.
            w_real = -torch.exp(self.log_w_real)
            w_imag = self.w_imag
            w = w_real + 1j * w_imag
        else:
            w = _r2c(self.w)  # (..., N)
        return w

    def forward(self, state=None, rate=1.0, L=None):
        """
        state: (..., s, N) extra tensor that augments B
        rate: sampling rate factor
        L: target kernel length; defaults to int(self.L / rate)

        returns: tuple (k_B, k_state)
            k_B: (C, H, L) kernel for the input
            k_state: (S, C, H, L) kernel for the extra state, or None when
                state is None
        """
        # Handle sampling rate logic
        # The idea is that this kernel's length (in continuous units) is self.L, while we are asked to provide a kernel of length L at (relative) sampling rate rate
        # If either are not passed in, assume we're not asked to change the scale of our kernel
        assert not (rate is None and L is None)
        if rate is None:
            rate = self.L / L
        if L is None:
            L = int(self.L / rate)

        # Increase the internal length if needed
        while rate * L > self.L:
            self.double_length()

        dt = torch.exp(self.log_dt) * rate
        B = _r2c(self.B)
        C = _r2c(self.C)
        P = _r2c(self.P)
        Q = P.conj() if self.Q is None else _r2c(self.Q)
        w = self._w()

        if rate == 1.0:
            # Use cached FFT nodes
            omega, z = _r2c(self.omega), _r2c(self.z)  # (self.L // 2 + 1)
        else:
            omega, z = self._omega(
                int(self.L / rate), dtype=w.dtype, device=w.device, cache=False
            )

        if self.tie_state:
            B = repeat(B, "... 1 n -> ... h n", h=self.H)
            P = repeat(P, "... 1 n -> ... h n", h=self.H)
            Q = repeat(Q, "... 1 n -> ... h n", h=self.H)

        # Augment B
        if state is not None:
            # Have to "unbilinear" the state to put it into the same "type" as B
            # Compute 1/dt * (I + dt/2 A) @ state

            # Can do this without expanding (maybe minor speedup using conj symmetry in theory), but it's easier to read this way
            s = _conj(state) if state.size(-1) == self.N else state  # (B H N)
            sA = s * _conj(w) - oe.contract(  # (B H N)
                "bhm, rhm, rhn -> bhn", s, _conj(Q), _conj(P)
            )
            s = s / dt.unsqueeze(-1) + sA / 2
            s = s[..., : self.N]

            B = torch.cat([s, B], dim=-3)  # (s+1, H, N)

        # Incorporate dt into A
        w = w * dt.unsqueeze(-1)  # (H N)

        # Stack B and p, C and q for convenient batching
        B = torch.cat([B, P], dim=-3)  # (s+1+r, H, N)
        C = torch.cat([C, Q], dim=-3)  # (c+r, H, N)

        # Incorporate B and C batch dimensions
        v = B.unsqueeze(-3) * C.unsqueeze(-4)  # (s+1+r, c+r, H, N)

        # Calculate resolvent at omega
        if has_cauchy_extension and z.dtype == torch.cfloat:
            r = cauchy_mult(v, z, w, symmetric=True)
        else:
            r = cauchy_conj(v, z, w)
        r = r * dt[None, None, :, None]  # (S+1+R, C+R, H, L)

        # Low-rank Woodbury correction
        if self.rank == 1:
            k_f = r[:-1, :-1, :, :] - r[:-1, -1:, :, :] * r[-1:, :-1, :, :] / (
                1 + r[-1:, -1:, :, :]
            )
        elif self.rank == 2:
            r00 = r[: -self.rank, : -self.rank, :, :]
            r01 = r[: -self.rank, -self.rank :, :, :]
            r10 = r[-self.rank :, : -self.rank, :, :]
            r11 = r[-self.rank :, -self.rank :, :, :]
            # Closed-form inverse of the 2x2 block (I + r11).
            det = (1 + r11[:1, :1, :, :]) * (1 + r11[1:, 1:, :, :]) - r11[
                :1, 1:, :, :
            ] * r11[1:, :1, :, :]
            s = (
                r01[:, :1, :, :] * (1 + r11[1:, 1:, :, :]) * r10[:1, :, :, :]
                + r01[:, 1:, :, :] * (1 + r11[:1, :1, :, :]) * r10[1:, :, :, :]
                - r01[:, :1, :, :] * (r11[:1, 1:, :, :]) * r10[1:, :, :, :]
                - r01[:, 1:, :, :] * (r11[1:, :1, :, :]) * r10[:1, :, :, :]
            )
            s = s / det
            k_f = r00 - s
        else:
            r00 = r[: -self.rank, : -self.rank, :, :]
            r01 = r[: -self.rank, -self.rank :, :, :]
            r10 = r[-self.rank :, : -self.rank, :, :]
            r11 = r[-self.rank :, -self.rank :, :, :]
            # Move the (rank, rank) block to the last two axes for inversion.
            r11 = rearrange(r11, "a b h n -> h n a b")
            r11 = torch.linalg.inv(torch.eye(self.rank, device=r.device) + r11)
            r11 = rearrange(r11, "h n a b -> a b h n")
            k_f = r00 - torch.einsum(
                "i j h n, j k h n, k l h n -> i l h n", r01, r11, r10
            )

        # Final correction for the bilinear transform
        k_f = k_f * 2 / (1 + omega)

        # Move from frequency to coefficients
        k = torch.fft.irfft(k_f)  # (S+1, C, H, L)

        # Truncate to target length
        k = k[..., :L]

        if state is not None:
            k_state = k[:-1, :, :, :]  # (S, C, H, L)
        else:
            k_state = None
        k_B = k[-1, :, :, :]  # (C H L)
        return k_B, k_state

    @torch.no_grad()
    def double_length(self):
        """Double the internal kernel length.

        Returns
        -------
        None
            ``self.C``, ``self.L`` and the cached FFT nodes are updated in
            place via ``_setup_C(double_length=True)``.
        """
        self._setup_C(double_length=True)

    def _setup_linear(self):
        """Set up parameters for fast linear recurrent stepping.

        Returns
        -------
        None
            The function stores stepping parameters in self.step_params.
        """
        w = self._w()
        B = _r2c(self.B)  # (H N)
        P = _r2c(self.P)
        Q = P.conj() if self.Q is None else _r2c(self.Q)

        # Prepare Linear stepping
        dt = torch.exp(self.log_dt)
        D = (2.0 / dt.unsqueeze(-1) - w).reciprocal()  # (H, N)
        R = (
            torch.eye(self.rank, dtype=w.dtype, device=w.device)
            + 2 * oe.contract("r h n, h n, s h n -> h r s", Q, D, P).real
        )  # (H r r)
        Q_D = rearrange(Q * D, "r h n -> h r n")
        R = torch.linalg.solve(R.to(Q_D), Q_D)  # (H r N)
        R = rearrange(R, "h r n -> r h n")

        self.step_params = {
            "D": D,  # (H N)
            "R": R,  # (r H N)
            "P": P,  # (r H N)
            "Q": Q,  # (r H N)
            "B": B,  # (1 H N)
            "E": 2.0 / dt.unsqueeze(-1) + w,  # (H N)
        }

    def _step_state_linear(self, u=None, state=None):
        """
        Version of the step function that has time O(N) instead of O(N^2) per step, which takes advantage of the DPLR form and bilinear discretization.

        Unfortunately, as currently implemented it's about 2x slower because it calls several sequential operations. Perhaps a fused CUDA kernel implementation would be much faster

        u: (H) input
        state: (H, N/2) state with conjugate pairs
        Optionally, the state can have last dimension N
        Returns: same shape as state
        """
        C = _r2c(self.C)  # View used for dtype/device

        if u is None:  # Special case used to find dA
            u = torch.zeros(self.H, dtype=C.dtype, device=C.device)
        if state is None:  # Special case used to find dB
            state = torch.zeros(self.H, self.N, dtype=C.dtype, device=C.device)

        step_params = self.step_params.copy()
        # Only store half of the conjugate pairs; should be true by default
        if state.size(-1) == self.N:
            # There should be a slightly faster way using conjugate symmetry
            def contract_fn(p, x, y):
                return oe.contract(
                    "r h n, r h m, ... h m -> ... h n", _conj(p), _conj(x), _conj(y)
                )[
                    ..., : self.N
                ]  # inner outer product

        else:
            assert state.size(-1) == 2 * self.N
            # Full-size state: extend every step parameter with its conjugate.
            step_params = {k: _conj(v) for k, v in step_params.items()}

            # TODO worth setting up a contract_expression in default_state if we want to use this at inference time for stepping
            def contract_fn(p, x, y):
                return oe.contract(
                    "r h n, r h m, ... h m -> ... h n", p, x, y
                )  # inner outer product

        D = step_params["D"]  # (H N)
        E = step_params["E"]  # (H N)
        R = step_params["R"]  # (r H N)
        P = step_params["P"]  # (r H N)
        Q = step_params["Q"]  # (r H N)
        B = step_params["B"]  # (1 H N)

        new_state = E * state - contract_fn(P, Q, state)  # (B H N)
        new_state = new_state + 2.0 * B * u.unsqueeze(-1)  # (B H N)
        new_state = D * (new_state - contract_fn(P, R, new_state))

        return new_state

    def _setup_state(self):
        """Construct dA and dB for discretized state equation"""
        # Construct dA and dB by using the stepping
        self._setup_linear()
        # Just returns a view that we use for finding dtype/device
        C = _r2c(self.C)

        # Step every basis vector once (with zero input) to read off dA.
        state = torch.eye(2 * self.N, dtype=C.dtype, device=C.device).unsqueeze(
            -2
        )  # (2N 1 2N)
        dA = self._step_state_linear(state=state)
        dA = rearrange(dA, "n h m -> h m n")
        self.dA = dA  # (H 2N 2N)

        # Step the zero state with unit input to read off dB.
        u = C.new_ones(self.H)
        dB = self._step_state_linear(u=u)
        dB = _conj(dB)
        self.dB = rearrange(dB, "1 h n -> h n")  # (H 2N)

    def _step_state(self, u, state):
        """Must be called after self.default_state() is used to construct an initial state!

        Parameters
        ----------
        u : torch.Tensor
            Input tensor for the current time step.
        state : torch.Tensor
            Current recurrent state.

        Returns
        -------
        torch.Tensor
            Updated recurrent state, ``dA @ state + dB * u`` evaluated with
            the contractions cached by ``default_state``.
        """
        next_state = self.state_contraction(self.dA, state) + self.input_contraction(
            self.dB, u
        )
        return next_state

    def setup_step(self, mode="dense"):
        """Set up dA, dB, dC discretized parameters for stepping

        mode: "dense", "linear" or "diagonal"; any other value raises
            NotImplementedError
        """
        self._setup_state()

        # Calculate original C
        dA_L = power(self.L, self.dA)
        Id = torch.eye(self.dA.size(-1)).to(dA_L)
        C = _conj(_r2c(self.C))  # (C H 2N)

        # Undo the (I - dA^L) factor applied to C by solving against it.
        dC = torch.linalg.solve(
            Id - dA_L.transpose(-1, -2),
            C.unsqueeze(-1),
        ).squeeze(-1)
        self.dC = dC

        # Do special preprocessing for different step modes

        self._step_mode = mode
        if mode == "linear":
            # Linear case: special step function for the state, we need to handle output
            # use conjugate symmetry by default, which affects the output projection
            self.dC = 2 * self.dC[:, :, : self.N]
        elif mode == "diagonal":
            # Eigendecomposition of the A matrix
            L, V = torch.linalg.eig(self.dA)
            V_inv = torch.linalg.inv(V)
            # Check that the eigendecomposition is correct
            if self.verbose:
                print(
                    "Diagonalization error:",
                    torch.dist(V @ torch.diag_embed(L) @ V_inv, self.dA),
                )

            # Change the parameterization to diagonalize
            self.dA = L
            self.dB = oe.contract("h n m, h m -> h n", V_inv, self.dB)
            self.dC = oe.contract("h n m, c h n -> c h m", V, self.dC)

        elif mode == "dense":
            pass
        else:
            raise NotImplementedError(
                "NPLR Kernel step mode must be {'dense' | 'linear' | 'diagonal'}"
            )

    def default_state(self, *batch_shape):
        """Create an initial recurrent state.

        Reads ``self._step_mode``, which is set by ``setup_step``, and caches
        the contraction expressions used by ``step`` for this batch shape.

        Parameters
        ----------
        *batch_shape : int
            Batch dimensions of the state.

        Returns
        -------
        torch.Tensor
            Zero-initialized state of shape ``(*batch_shape, H, N)``, where N
            is the stored (half) state size in "linear" mode and twice that
            otherwise.
        """
        C = _r2c(self.C)
        N = C.size(-1)
        H = C.size(-2)

        # Cache the tensor contractions we will later do, for efficiency
        # These are put in this function because they depend on the batch size
        if self._step_mode != "linear":
            N *= 2

            if self._step_mode == "diagonal":
                self.state_contraction = oe.contract_expression(
                    "h n, ... h n -> ... h n",
                    (H, N),
                    batch_shape + (H, N),
                )
            else:
                # Dense (quadratic) case: expand all terms
                self.state_contraction = oe.contract_expression(
                    "h m n, ... h n -> ... h m",
                    (H, N, N),
                    batch_shape + (H, N),
                )

            self.input_contraction = oe.contract_expression(
                "h n, ... h -> ... h n",
                (H, N),  # self.dB.shape
                batch_shape + (H,),
            )

        self.output_contraction = oe.contract_expression(
            "c h n, ... h n -> ... c h",
            (C.shape[0], H, N),  # self.dC.shape
            batch_shape + (H, N),
        )

        state = torch.zeros(*batch_shape, H, N, dtype=C.dtype, device=C.device)
        return state

    def step(self, u, state):
        """Must have called self.setup_step() and created state with self.default_state() before calling this

        Returns the tuple (y, new_state): the output for this step and the
        updated state.
        """
        if self._step_mode == "linear":
            new_state = self._step_state_linear(u, state)
        else:
            new_state = self._step_state(u, state)
        y = self.output_contraction(self.dC, new_state)
        return y, new_state

    def register(self, name, tensor, trainable=False, lr=None, wd=None):
        """Register a tensor as a parameter or buffer.

        Parameters
        ----------
        name : str
            Name used to register the tensor.
        tensor : torch.Tensor
            Tensor to register.
        trainable : bool
            If True, register the tensor as a trainable parameter.
            Otherwise, register it as a buffer.
        lr : float
            Optional learning rate metadata.
        wd : float
            Optional weight decay metadata.

        Returns
        -------
        None
            The tensor is registered on the module. For trainable tensors,
            any given ``lr``/``wd`` are stored as a dict in the parameter's
            ``_optim`` attribute under the keys "lr" and "weight_decay".
        """
        if trainable:
            self.register_parameter(name, nn.Parameter(tensor))
        else:
            self.register_buffer(name, tensor)

        optim = {}
        if trainable and lr is not None:
            optim["lr"] = lr
        if trainable and wd is not None:
            optim["weight_decay"] = wd
        if len(optim) > 0:
            setattr(getattr(self, name), "_optim", optim)
[docs]
class HippoSSKernel(nn.Module):
    """Wrapper around SSKernel that generates A, B, C, dt according to HiPPO arguments.

    The wrapped kernel is expected to support the interface
        forward()
        default_state()
        setup_step()
        step()
    """

    def __init__(
        self,
        H,
        N=64,
        L=1,
        measure="legs",
        rank=1,
        channels=1,
        dt_min=0.001,
        dt_max=0.1,
        trainable=None,
        lr=None,
        # Multiply by (I - A)^L after initialization; can be turned off
        # for initialization speed.
        length_correction=True,
        hurwitz=False,
        tie_state=False,
        precision=1,
        resample=False,
        verbose=False,
    ):
        """Build the HiPPO-initialized state-space kernel.

        Parameters
        ----------
        H : int
            Number of hidden features.
        N : int
            State dimension. ``C`` is created with ``N // 2`` complex entries
            along its last axis.
        L : int
            Maximum sequence length. A falsy value is replaced by 1.
        measure : str
            HiPPO measure name passed to ``nplr``.
        rank : int
            Rank of the low-rank correction passed to ``nplr``.
        channels : int
            Number of output channels, used as the leading dimension of ``C``.
        dt_min : float
            Lower bound of the range from which step sizes are sampled.
        dt_max : float
            Upper bound of the range from which step sizes are sampled.
        trainable : bool or dict
            Forwarded to ``SSKernelNPLR``.
        lr : float
            Forwarded to ``SSKernelNPLR``.
        length_correction : bool
            Forwarded to ``SSKernelNPLR``.
        hurwitz : bool
            Forwarded to ``SSKernelNPLR``.
        tie_state : bool
            Forwarded to ``SSKernelNPLR``.
        precision : int
            Use 2 for double precision; any other value selects single
            precision.
        resample : bool
            If True, ``self.rate`` is set to None; otherwise to 1.0.
        verbose : bool
            Forwarded to ``SSKernelNPLR``.
        """
        super().__init__()
        self.N = N
        self.H = H
        L = L or 1

        # Choose matching real and complex dtypes from the precision flag.
        self.precision = precision
        if precision == 2:
            dtype, cdtype = torch.double, torch.cdouble
        else:
            dtype, cdtype = torch.float, torch.cfloat

        self.rate = None if resample else 1.0
        self.channels = channels

        # Sample log(dt) uniformly between log(dt_min) and log(dt_max).
        log_dt_min = math.log(dt_min)
        log_dt_span = math.log(dt_max) - log_dt_min
        log_dt = torch.rand(self.H, dtype=dtype) * log_dt_span + log_dt_min

        # The fourth value returned by nplr is not needed here.
        w, p, B, _ = nplr(measure, self.N, rank, dtype=dtype)
        C = torch.randn(channels, self.H, self.N // 2, dtype=cdtype)

        self.kernel = SSKernelNPLR(
            L,
            w,
            p,
            B,
            C,
            log_dt,
            hurwitz=hurwitz,
            trainable=trainable,
            lr=lr,
            tie_state=tie_state,
            length_correction=length_correction,
            verbose=verbose,
        )

    def forward(self, L=None, rate=1.0):
        """Evaluate the wrapped kernel and return its first output as float32.

        Parameters
        ----------
        L : int
            Length forwarded to the wrapped kernel.
        rate : float
            Sampling rate factor forwarded to the wrapped kernel.

        Returns
        -------
        torch.Tensor
            First element of the pair returned by the wrapped kernel, cast
            with ``.float()``.
        """
        conv_kernel, _ = self.kernel(rate=rate, L=L)
        return conv_kernel.float()

    def step(self, u, state, **kwargs):
        """Run one recurrent step of the wrapped kernel.

        Returns
        -------
        tuple
            The step output cast with ``.float()`` and the next state.
        """
        y, next_state = self.kernel.step(u, state, **kwargs)
        return y.float(), next_state

    def default_state(self, *args, **kwargs):
        """Return the wrapped kernel's default state; arguments are forwarded."""
        return self.kernel.default_state(*args, **kwargs)
[docs]
class S4(nn.Module):
"""Structured State Space Sequence layer."""
def __init__(
    self,
    d_model,
    d_state=64,
    l_max=1,
    channels=1,
    bidirectional=False,
    # Arguments for FF
    activation="gelu",
    postact=None,
    initializer=None,
    weight_norm=False,
    hyper_act=None,
    dropout=0.0,
    transposed=True,
    verbose=False,
    **kernel_args,
):
    """Set up the S4 layer.

    Parameters
    ----------
    d_model : int
        Hidden feature dimension.
    d_state : int
        State dimension passed to the kernel as ``N``.
    l_max : int
        Maximum sequence length passed to the kernel as ``L``.
    channels : int
        Number of state-space output channels.
    bidirectional : bool
        If True, the kernel is built with twice as many channels so that
        forward and backward halves can be combined in ``forward``.
    activation : str
        Activation applied after the state-space convolution.
    postact : str
        Activation applied after the output linear layer.
    initializer : str
        Initializer used for the output linear layer.
    weight_norm : bool
        If True, apply weight normalization to the output linear layer.
    hyper_act : str
        Optional activation for GLU-style multiplicative modulation
        (https://arxiv.org/abs/2002.05202). When given, the number of
        channels is doubled.
    dropout : float
        Dropout probability; values <= 0 disable dropout.
    transposed : bool
        If True, inputs use shape (batch_size, d_model, length).
        If False, inputs use shape (batch_size, length, d_model).
    verbose : bool
        Forwarded to HippoSSKernel.
    **kernel_args
        Additional keyword arguments passed to HippoSSKernel.
    """
    super().__init__()
    self.h = d_model
    self.n = d_state
    self.bidirectional = bidirectional
    self.channels = channels
    self.transposed = transposed

    # Optional multiplicative modulation: the extra half of the channels
    # acts as the gate.
    self.hyper = hyper_act is not None
    skip_channels = channels
    if self.hyper:
        skip_channels = skip_channels * 2
        self.hyper_activation = Activation(hyper_act)

    # Skip-connection weights, sized before the bidirectional doubling.
    self.D = nn.Parameter(torch.randn(skip_channels, self.h))

    # SSM kernel: a bidirectional layer needs one kernel per direction.
    kernel_channels = skip_channels * 2 if self.bidirectional else skip_channels
    self.kernel = HippoSSKernel(
        self.h,
        N=self.n,
        L=l_max,
        channels=kernel_channels,
        verbose=verbose,
        **kernel_args,
    )

    # Pointwise nonlinearity and dropout.
    self.activation = Activation(activation)
    if dropout > 0.0:
        dropout_cls = nn.Dropout2d if self.transposed else nn.Dropout
        self.dropout = dropout_cls(dropout)
    else:
        self.dropout = nn.Identity()

    # Position-wise output transform to mix features.
    self.output_linear = LinearActivation(
        self.h * self.channels,
        self.h,
        transposed=self.transposed,
        initializer=initializer,
        activation=postact,
        activate=True,
        weight_norm=weight_norm,
    )
[docs]
def forward(self, u, rate=1.0, **kwargs):
    """Apply the S4 layer to a full input sequence.

    Parameters
    ----------
    u : torch.Tensor
        Input tensor. If transposed is True, the shape is
        (batch_size, d_model, sequence_length). Otherwise, the shape is
        (batch_size, sequence_length, d_model).
    rate : float
        Sampling rate factor forwarded to the kernel.
    **kwargs
        Accepted for interface compatibility and not used.

    Returns
    -------
    tuple
        The output tensor and None, the latter for compatibility with
        recurrent interfaces.
    """
    if not self.transposed:
        u = u.transpose(-1, -2)
    seq_len = u.size(-1)
    # Zero-pad to twice the length so the FFT product is a linear
    # (not circular) convolution.
    fft_len = 2 * seq_len

    # State-space convolution kernel.
    kernel = self.kernel(L=seq_len, rate=rate)
    if self.bidirectional:
        # First half of the channels is the forward kernel, second half
        # the backward one (flipped in time).
        fwd, bwd = rearrange(kernel, "(s c) h l -> s c h l", s=2)
        kernel = F.pad(fwd, (0, seq_len)) + F.pad(bwd.flip(-1), (seq_len, 0))

    # Convolution in the frequency domain.
    kernel_f = torch.fft.rfft(kernel, n=fft_len)
    input_f = torch.fft.rfft(u, n=fft_len)
    product_f = oe.contract("bhl,chl->bchl", input_f, kernel_f)
    y = torch.fft.irfft(product_f, n=fft_len)[..., :seq_len]

    # D term of the state-space equation, essentially a skip connection.
    y = y + oe.contract("bhl,ch->bchl", u, self.D)

    # Optional multiplicative gating: the second half of the channels
    # gates the first half.
    if self.hyper:
        y, gate = rearrange(y, "b (s c) h l -> s b c h l", s=2)
        y = self.hyper_activation(gate) * y

    # Flatten channels into the feature axis.
    y = rearrange(y, "... c h l -> ... (c h) l")
    y = self.dropout(self.activation(y))
    if not self.transposed:
        y = y.transpose(-1, -2)
    return self.output_linear(y), None
[docs]
def step(self, u, state):
    """Advance the layer by one time step in recurrent mode.

    Intended to be used during validation; the module must not be in
    training mode.

    u: (B H)
    state: (B H N)
    Returns: output (B H), state (B H N)
    """
    assert not self.training
    kernel_out, next_state = self.kernel.step(u, state)  # (B C H)
    # Skip connection: broadcast the input over the channel axis.
    skip = u.unsqueeze(-2) * self.D
    y = rearrange(kernel_out + skip, "... c h -> ... (c h)")
    y = self.activation(y)
    if self.transposed:
        # The output layer expects a trailing length axis in this layout.
        y = self.output_linear(y.unsqueeze(-1)).squeeze(-1)
    else:
        y = self.output_linear(y)
    return y, next_state
[docs]
def default_state(self, *batch_shape, device=None):
    """Return the kernel's default state for the given batch shape.

    Parameters
    ----------
    *batch_shape : int
        Leading batch dimensions, forwarded to the kernel.
    device
        Accepted for interface compatibility; not used by this method.

    Returns
    -------
    object
        Whatever ``self.kernel.default_state`` returns.
    """
    initial_state = self.kernel.default_state(*batch_shape)
    return initial_state
@property
def d_state(self):
    """Total state size: hidden feature dimension times state dimension."""
    hidden, state_size = self.h, self.n
    return hidden * state_size
@property
def d_output(self):
    """Output feature dimension, equal to the hidden dimension ``self.h``."""
    width = self.h
    return width
@property
def state_to_tensor(self):
    """Return a callable that flattens the last two axes of a state.

    The callable maps a tensor of shape ``(..., h, n)`` to shape
    ``(..., h * n)``, i.e. the pattern ``"... h n -> ... (h n)"``.

    The previous implementation called ``rearrange(pattern, state)``, but
    einops expects the tensor first and the pattern second, so the returned
    callable could not be applied to a state. Flattening with the tensor's
    own method performs the same reshape without that pitfall.

    Returns
    -------
    callable
        Function taking a state tensor and returning it with its last two
        axes merged.
    """
    return lambda state: state.flatten(start_dim=-2)