Source code for qumphy.models.s42

"""
File: qumphy/models/s42.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: Standalone version of Structured (Sequence) State Space (S4) model.
"""

# version from the SSM ECG repo
# https://github.com/HazyResearch/state-spaces/blob/main/src/models/sequence/ss/standalone/s4.py

from functools import partial
import math
import numpy as np
from scipy import special as ss
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
import opt_einsum as oe
from pykeops.torch import Genred


""" Cauchy kernel """

try:  # Try CUDA extension
    from extensions.cauchy.cauchy import cauchy_mult

    has_cauchy_extension = True
except ImportError:
    # log.warn(
    #     "CUDA extension for cauchy multiplication not found. Install by going to extensions/cauchy/ and running `python setup.py install`. This should speed up end-to-end training by 10-50%"
    # )
    has_cauchy_extension = False


def _broadcast_dims(*tensors):
    """Broadcast tensors by adding leading singleton dimensions.

    Parameters
    ----------
    *tensors : torch.Tensor
        Tensors with possibly different numbers of dimensions.

    Returns
    -------
    list
        List of reshaped tensors with the same number of dimensions.
    """
    max_dim = max([len(tensor.shape) for tensor in tensors])
    tensors = [
        tensor.view((1,) * (max_dim - len(tensor.shape)) + tensor.shape)
        for tensor in tensors
    ]
    return tensors


_c2r = torch.view_as_real
_r2c = torch.view_as_complex


def _conj(x):
    """Broadcast tensors by adding leading singleton dimensions.

    Parameters
    ----------
    *tensors : torch.Tensor
        Tensors with possibly different numbers of dimensions.

    Returns
    -------
    list
        List of reshaped tensors with the same number of dimensions.
    """
    return torch.cat([x, x.conj()], dim=-1)


if tuple(map(int, torch.__version__.split(".")[:2])) >= (1, 10):

    def _resolve_conj(x):
        """Concatenate a complex tensor with its conjugate.

        Parameters
        ----------
        x : torch.Tensor
            Complex-valued input tensor.

        Returns
        -------
        torch.Tensor
            Tensor containing the input and its complex conjugate along the last
            dimension.
        """
        return x.conj().resolve_conj()

else:

    def _resolve_conj(x):
        """Return the resolved complex conjugate of a tensor.

        Parameters
        ----------
        x : torch.Tensor
            Complex-valued input tensor.

        Returns
        -------
        torch.Tensor
            Resolved conjugate tensor.
        """
        return x.conj()


[docs] def cauchy_conj(v, z, w): """Compute the Cauchy multiplication using PyKeOps. Parameters ---------- v : torch.Tensor Numerator tensor. z : torch.Tensor Evaluation points. w : torch.Tensor Complex poles. Returns ------- torch.Tensor Result of the Cauchy multiplication. """ expr_num = "z * ComplexReal(v) - Real2Complex(Sum(v * w))" expr_denom = "ComplexMult(z-w, z-Conj(w))" cauchy_mult = Genred( f"ComplexDivide({expr_num}, {expr_denom})", [ "v = Vj(2)", "z = Vi(2)", "w = Vj(2)", ], reduction_op="Sum", axis=1, ) v, z, w = _broadcast_dims(v, z, w) v = _c2r(v) z = _c2r(z) w = _c2r(w) r = 2 * cauchy_mult(v, z, w, backend="GPU") return _r2c(r)
""" simple nn.Module components """
[docs] def Activation(activation=None, dim=-1): """Compute the Cauchy multiplication using PyKeOps. Parameters ---------- v : torch.Tensor Numerator tensor. z : torch.Tensor Evaluation points. w : torch.Tensor Complex poles. Returns ------- torch.Tensor Result of the Cauchy multiplication. """ if activation in [None, "id", "identity", "linear"]: return nn.Identity() elif activation == "tanh": return nn.Tanh() elif activation == "relu": return nn.ReLU() elif activation == "gelu": return nn.GELU() elif activation in ["swish", "silu"]: return nn.SiLU() elif activation == "glu": return nn.GLU(dim=dim) elif activation == "sigmoid": return nn.Sigmoid() else: raise NotImplementedError( "hidden activation '{}' is not implemented".format(activation) )
[docs] def get_initializer(name, activation=None): """Get a weight initialization function. Parameters ---------- name : str Name of the initializer. Supported values are "uniform", "normal", "xavier", "zero", and "one". activation : str Activation function used to determine the initialization gain. Returns ------- callable Weight initialization function. """ if activation in [None, "id", "identity", "linear", "modrelu"]: nonlinearity = "linear" elif activation in ["relu", "tanh", "sigmoid"]: nonlinearity = activation elif activation in ["gelu", "swish", "silu"]: nonlinearity = "relu" # Close to ReLU so approximate with ReLU's gain else: raise NotImplementedError( f"get_initializer: activation {activation} not supported" ) if name == "uniform": initializer = partial(torch.nn.init.kaiming_uniform_, nonlinearity=nonlinearity) elif name == "normal": initializer = partial(torch.nn.init.kaiming_normal_, nonlinearity=nonlinearity) elif name == "xavier": initializer = torch.nn.init.xavier_normal_ elif name == "zero": initializer = partial(torch.nn.init.constant_, val=0) elif name == "one": initializer = partial(torch.nn.init.constant_, val=1) else: raise NotImplementedError( f"get_initializer: initializer type {name} not supported" ) return initializer
[docs] class TransposedLinear(nn.Module): """Linear module on the second-to-last dimension""" def __init__(self, d_input, d_output, bias=True): """Initialize the transposed linear layer. Parameters ---------- d_input : int Number of input features. d_output : int Number of output features. bias : bool If True, include a bias parameter. """ super().__init__() self.weight = nn.Parameter(torch.empty(d_output, d_input)) # nn.Linear default init nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) # nn.init.kaiming_uniform_(self.weight, nonlinearity='linear') # should be equivalent if bias: self.bias = nn.Parameter(torch.empty(d_output, 1)) bound = 1 / math.sqrt(d_input) nn.init.uniform_(self.bias, -bound, bound) else: self.bias = 0.0
[docs] def forward(self, x): """Run a forward pass through the transposed linear layer. Parameters ---------- x : torch.Tensor Input tensor. Returns ------- torch.Tensor Output tensor after applying the linear transformation. """ return oe.contract("... u l, v u -> ... v l", x, self.weight) + self.bias
[docs] def LinearActivation( d_input, d_output, bias=True, zero_bias_init=False, transposed=False, initializer=None, activation=None, activate=False, # Apply activation as part of this module weight_norm=False, **kwargs, ): """Create a linear layer with optional initialization and activation. Parameters ---------- d_input : int Number of input features. d_output : int Number of output features. bias : bool If True, include a bias parameter. zero_bias_init : bool If True, initialize the bias with zeros. transposed : bool If True, use TransposedLinear instead of nn.Linear. initializer : str Name of the weight initializer. activation : str Name of the activation function. activate : bool If True, append the activation function to the linear layer. weight_norm : bool If True, apply weight normalization to the linear layer. **kwargs Additional keyword arguments passed to the linear layer. Returns ------- nn.Module Linear module, optionally followed by an activation function. """ # Construct core module linear_cls = TransposedLinear if transposed else nn.Linear if activation == "glu": d_output *= 2 linear = linear_cls(d_input, d_output, bias=bias, **kwargs) # Initialize weight if initializer is not None: get_initializer(initializer, activation)(linear.weight) # Initialize bias if bias and zero_bias_init: nn.init.zeros_(linear.bias) # Weight norm if weight_norm: linear = nn.utils.weight_norm(linear) if activate and activation is not None: activation = Activation(activation, dim=-2 if transposed else -1) linear = nn.Sequential(linear, activation) return linear
""" Misc functional utilities """
[docs] def krylov(L, A, b, c=None, return_power=False): """Compute a Krylov sequence. Parameters ---------- L : int Length of the Krylov sequence. A : torch.Tensor Square transition matrix. b : torch.Tensor Initial vector. c : torch.Tensor Optional projection vector. return_power : bool If True, also return A raised to the power L - 1. Returns ------- torch.Tensor or tuple Krylov sequence, optionally together with A raised to the power L - 1. """ # TODO There is an edge case if L=1 where output doesn't get broadcasted, which might be an issue if caller is expecting broadcasting semantics... can deal with it if it arises x = b.unsqueeze(-1) # (..., N, 1) A_ = A AL = None if return_power: AL = torch.eye(A.shape[-1], dtype=A.dtype, device=A.device) _L = L - 1 done = L == 1 # loop invariant: _L represents how many indices left to compute while not done: if return_power: if _L % 2 == 1: AL = A_ @ AL _L //= 2 # Save memory on last iteration length = x.shape[-1] if L - length <= length: done = True _x = x[..., : L - length] else: _x = x _x = A_ @ _x # there might be a more efficient way of ordering axes x = torch.cat([x, _x], dim=-1) if not done: A_ = A_ @ A_ assert x.shape[-1] == L if c is not None: x = torch.einsum("...nl, ...n -> ...l", x, c) x = x.contiguous() # WOW!! if return_power: return x, AL else: return x
[docs] def power(L, A, v=None): """Compute a matrix power and optional scan reduction. Parameters ---------- L : int Power to which the matrix is raised. A : torch.Tensor Square matrix of shape (..., N, N). v : torch.Tensor Optional tensor used to compute the scan sum over powers of A. Returns ------- torch.Tensor or tuple If v is None, returns A raised to the power L. Otherwise, returns the matrix power and the scan reduction. """ Id = torch.eye(A.shape[-1]).to(A) # , dtype=A.dtype, device=A.device) powers = [A] largest_pow = 1 while True: if L % 2 == 1: Id = powers[-1] @ Id L //= 2 if L == 0: break largest_pow *= 2 powers.append(powers[-1] @ powers[-1]) if v is None: return Id # Invariants: # powers[-1] := A^l # l := largest po2 at most L # Note that an alternative divide and conquer to compute the reduction is possible and can be embedded into the above loop without caching intermediate powers of A # We do this reverse divide-and-conquer for efficiency reasons: # 1) it involves fewer padding steps for non-po2 L # 2) it involves more contiguous arrays # Take care of edge case for non-po2 arrays # Note that this initial step is a no-op for the case of power of 2 (l == L) k = v.size(-1) - largest_pow v_ = powers.pop() @ v[..., largest_pow:] v = v[..., :largest_pow] v[..., :k] = v[..., :k] + v_ # Handle reduction for power of 2 while v.size(-1) > 1: v = rearrange(v, "... (z l) -> ... z l", z=2) v = v[..., 0, :] + powers.pop() @ v[..., 1, :] return Id, v.squeeze(-1)
""" HiPPO utilities """
[docs] def embed_c2r(A): A = rearrange(A, "... m n -> ... m () n ()") A = np.pad(A, ((0, 0), (0, 1), (0, 0), (0, 1))) + np.pad( A, ((0, 0), (1, 0), (0, 0), (1, 0)) ) return rearrange(A, "m x n y -> (m x) (n y)")
[docs] def transition(measure, N, **measure_args): """A, B transition matrices for different measures measure: the type of measure legt - Legendre (translated) legs - Legendre (scaled) glagt - generalized Laguerre (translated) lagt, tlagt - previous versions of (tilted) Laguerre with slightly different normalization Parameters ---------- measure : str Type of HiPPO measure. Supported values include "legt", "legs", "glagt", "lagt", "fourier", "random", and "diagonal". N : int State dimension. **measure_args Additional arguments for the selected measure. Returns ------- tuple Tuple containing the transition matrix A and input matrix B. """ # Laguerre (translated) if measure == "lagt": b = measure_args.get("beta", 1.0) A = np.eye(N) / 2 - np.tril(np.ones((N, N))) B = b * np.ones((N, 1)) # Generalized Laguerre # alpha 0, beta small is most stable (limits to the 'lagt' measure) # alpha 0, beta 1 has transition matrix A = [lower triangular 1] elif measure == "glagt": alpha = measure_args.get("alpha", 0.0) beta = measure_args.get("beta", 0.01) A = -np.eye(N) * (1 + beta) / 2 - np.tril(np.ones((N, N)), -1) B = ss.binom(alpha + np.arange(N), np.arange(N))[:, None] L = np.exp( 0.5 * (ss.gammaln(np.arange(N) + alpha + 1) - ss.gammaln(np.arange(N) + 1)) ) A = (1.0 / L[:, None]) * A * L[None, :] B = ( (1.0 / L[:, None]) * B * np.exp(-0.5 * ss.gammaln(1 - alpha)) * beta ** ((1 - alpha) / 2) ) # Legendre (translated) elif measure == "legt": Q = np.arange(N, dtype=np.float64) R = (2 * Q + 1) ** 0.5 j, i = np.meshgrid(Q, Q) A = R[:, None] * np.where(i < j, (-1.0) ** (i - j), 1) * R[None, :] B = R[:, None] A = -A # Legendre (scaled) elif measure == "legs": q = np.arange(N, dtype=np.float64) col, row = np.meshgrid(q, q) r = 2 * q + 1 M = -(np.where(row >= col, r, 0) - np.diag(q)) T = np.sqrt(np.diag(2 * q + 1)) A = T @ M @ np.linalg.inv(T) B = np.diag(T)[:, None] B = ( B.copy() ) # Otherwise "UserWarning: given NumPY array is not writeable..." after torch.as_tensor(B) elif measure == "fourier": freqs = np.arange(N // 2) d = np.stack([freqs, np.zeros(N // 2)], axis=-1).reshape(-1)[:-1] A = 2 * np.pi * (np.diag(d, 1) - np.diag(d, -1)) A = A - embed_c2r(np.ones((N // 2, N // 2))) B = embed_c2r(np.ones((N // 2, 1)))[..., :1] elif measure == "random": A = np.random.randn(N, N) / N B = np.random.randn(N, 1) elif measure == "diagonal": A = -np.diag(np.exp(np.random.randn(N))) B = np.random.randn(N, 1) else: raise NotImplementedError return A, B
[docs] def rank_correction(measure, N, rank=1, dtype=torch.float): """Return low-rank matrix L such that A + L is normal Parameters ---------- measure : str Type of HiPPO measure. N : int State dimension. rank : int Rank of the correction matrix. dtype : torch.dtype Data type of the returned tensor. Returns ------- torch.Tensor Low-rank correction matrix. """ if measure == "legs": assert rank >= 1 P = torch.sqrt(0.5 + torch.arange(N, dtype=dtype)).unsqueeze(0) # (1 N) elif measure == "legt": assert rank >= 2 P = torch.sqrt(1 + 2 * torch.arange(N, dtype=dtype)) # (N) P0 = P.clone() P0[0::2] = 0.0 P1 = P.clone() P1[1::2] = 0.0 P = torch.stack([P0, P1], dim=0) # (2 N) elif measure == "lagt": assert rank >= 1 P = 0.5**0.5 * torch.ones(1, N, dtype=dtype) elif measure == "fourier": P = torch.ones(N, dtype=dtype) # (N) P0 = P.clone() P0[0::2] = 0.0 P1 = P.clone() P1[1::2] = 0.0 P = torch.stack([P0, P1], dim=0) # (2 N) else: raise NotImplementedError d = P.size(0) if rank > d: P = torch.cat([P, torch.zeros(rank - d, N, dtype=dtype)], dim=0) # (rank N) return P
[docs] def nplr(measure, N, rank=1, dtype=torch.float): """Convert a HiPPO matrix into normal plus low-rank form. Parameters ---------- measure : str Type of HiPPO measure. N : int State dimension. rank : int Rank of the low-rank correction. dtype : torch.dtype Floating point data type. Returns ------- tuple Tuple containing eigenvalues, low-rank correction, input vector, and eigenvector matrix. """ assert dtype == torch.float or torch.cfloat if measure == "random": dtype = torch.cfloat if dtype == torch.float else torch.cdouble # w = torch.randn(N//2, dtype=dtype) w = -torch.exp(torch.randn(N // 2)) + 1j * torch.randn(N // 2) P = torch.randn(rank, N // 2, dtype=dtype) B = torch.randn(N // 2, dtype=dtype) V = torch.eye(N, dtype=dtype)[..., : N // 2] # Only used in testing return w, P, B, V A, B = transition(measure, N) A = torch.as_tensor(A, dtype=dtype) # (N, N) B = torch.as_tensor(B, dtype=dtype)[:, 0] # (N,) P = rank_correction(measure, N, rank=rank, dtype=dtype) AP = A + torch.sum(P.unsqueeze(-2) * P.unsqueeze(-1), dim=-3) w, V = torch.linalg.eig(AP) # (..., N) (..., N, N) # V w V^{-1} = A # Only keep one of the conjugate pairs w = w[..., 0::2].contiguous() V = V[..., 0::2].contiguous() V_inv = V.conj().transpose(-1, -2) B = oe.contract("ij, j -> i", V_inv, B.to(V)) # V^* B P = oe.contract("ij, ...j -> ...i", V_inv, P.to(V)) # V^* P return w, P, B, V
[docs] def bilinear(dt, A, B=None): """Apply bilinear discretization to a continuous state-space system. Parameters ---------- dt : torch.Tensor Time step or timescale tensor. A : torch.Tensor Continuous transition matrix. B : torch.Tensor Optional continuous input matrix. Returns ------- tuple Discretized transition matrix and discretized input matrix. """ N = A.shape[-1] Id = torch.eye(N).to(A) A_backwards = Id - dt[:, None, None] / 2 * A A_forwards = Id + dt[:, None, None] / 2 * A if B is None: dB = None else: dB = dt[..., None] * torch.linalg.solve(A_backwards, B.unsqueeze(-1)).squeeze( -1 ) # (... N) dA = torch.linalg.solve(A_backwards, A_forwards) # (... N N) return dA, dB
[docs] class SSKernelNPLR(nn.Module): """Stores a representation of and computes the SSKernel function K_L(A^dt, B^dt, C) corresponding to a discretized state space, where A is Normal + Low Rank (NPLR) The class name stands for 'State-Space SSKernel for Normal Plus Low-Rank'. The parameters of this function are as follows. A: (... N N) the state matrix B: (... N) input matrix C: (... N) output matrix dt: (...) timescales / discretization step size p, q: (... P N) low-rank correction to A, such that Ap=A+pq^T is a normal matrix The forward pass of this Module returns: (... L) that represents represents FFT SSKernel_L(A^dt, B^dt, C) """ @torch.no_grad() def _setup_C(self, double_length=False): """Construct C~ from C double_length: current C is for length L, convert it to length 2L """ C = _r2c(self.C) self._setup_state() dA_L = power(self.L, self.dA) # Multiply C by I - dA_L C_ = _conj(C) prod = oe.contract("h m n, c h n -> c h m", dA_L.transpose(-1, -2), C_) if double_length: prod = -prod # Multiply by I + dA_L instead C_ = C_ - prod C_ = C_[..., : self.N] # Take conjugate pairs again self.C.copy_(_c2r(C_)) if double_length: self.L *= 2 self._omega(self.L, dtype=C.dtype, device=C.device, cache=True) def _omega(self, L, dtype, device, cache=True): """Calculate (and cache) FFT nodes and their "unprocessed" them with the bilinear transform This should be called everytime the internal length self.L changes Parameters ---------- L : int Kernel length. dtype : torch.dtype Data type of the nodes. device : torch.device Device on which the nodes are created. cache : bool If True, register the nodes as buffers. Returns ------- tuple Tuple containing omega and z nodes. """ omega = torch.tensor( np.exp(-2j * np.pi / (L)), dtype=dtype, device=device ) # \omega_{2L} omega = omega ** torch.arange(0, L // 2 + 1, device=device) z = 2 * (1 - omega) / (1 + omega) if cache: self.register_buffer("omega", _c2r(omega)) self.register_buffer("z", _c2r(z)) return omega, z def __init__( self, L, w, P, B, C, log_dt, hurwitz=False, trainable=None, lr=None, tie_state=False, length_correction=True, verbose=False, ): """ L: Maximum length; this module computes an SSM kernel of length L w: (N) p: (r, N) low-rank correction to A q: (r, N) A represented by diag(w) - pq^* B: (N) dt: (H) timescale per feature C: (H, C, N) system is 1-D to c-D (channels) hurwitz: tie pq and ensure w has negative real part trainable: toggle which of the parameters is trainable lr: add hook to set lr of hippo parameters specially (everything besides C) tie_state: tie all state parameters across the H hidden features length_correction: multiply C by (I - dA^L) - can be turned off when L is large for slight speedup at initialization (only relevant when N large as well) Note: tensor shape N here denotes half the true state size, because of conjugate symmetry """ super().__init__() self.hurwitz = hurwitz self.tie_state = tie_state self.verbose = verbose # Rank of low-rank correction self.rank = P.shape[-2] assert w.size(-1) == P.size(-1) == B.size(-1) == C.size(-1) self.H = log_dt.size(-1) self.N = w.size(-1) # Broadcast everything to correct shapes C = C.expand(torch.broadcast_shapes(C.shape, (1, self.H, self.N))) # (H, C, N) H = 1 if self.tie_state else self.H B = repeat(B, "n -> 1 h n", h=H) P = repeat(P, "r n -> r h n", h=H) w = repeat(w, "n -> h n", h=H) # Cache Fourier nodes every time we set up a desired length self.L = L if self.L is not None: self._omega(self.L, dtype=C.dtype, device=C.device, cache=True) # Register parameters # C is a regular parameter, not state # self.C = nn.Parameter(_c2r(C.conj().resolve_conj())) self.C = nn.Parameter(_c2r(_resolve_conj(C))) train = False if trainable is None: trainable = {} elif not trainable: trainable = {} elif trainable: trainable, train = {}, True self.register("log_dt", log_dt, trainable.get("dt", train), lr, 0.0) self.register("B", _c2r(B), trainable.get("B", train), lr, 0.0) self.register("P", _c2r(P), trainable.get("P", train), lr, 0.0) if self.hurwitz: # Some of the HiPPO methods have real part 0 log_w_real = torch.log(-w.real + 1e-3) w_imag = w.imag self.register("log_w_real", log_w_real, trainable.get("A", 0), lr, 0.0) self.register("w_imag", w_imag, trainable.get("A", train), lr, 0.0) self.Q = None else: self.register("w", _c2r(w), trainable.get("A", train), lr, 0.0) # self.register("Q", _c2r(P.clone().conj().resolve_conj()), trainable.get('P', train), lr, 0.0) Q = _resolve_conj(P.clone()) self.register("Q", _c2r(Q), trainable.get("P", train), lr, 0.0) if length_correction: self._setup_C() def _w(self): # Get the internal w (diagonal) parameter if self.hurwitz: w_real = -torch.exp(self.log_w_real) w_imag = self.w_imag w = w_real + 1j * w_imag else: w = _r2c(self.w) # (..., N) return w
[docs] def forward(self, state=None, rate=1.0, L=None): """ state: (..., s, N) extra tensor that augments B rate: sampling rate factor returns: (..., c+s, L) """ # Handle sampling rate logic # The idea is that this kernel's length (in continuous units) is self.L, while we are asked to provide a kernel of length L at (relative) sampling rate rate # If either are not passed in, assume we're not asked to change the scale of our kernel assert not (rate is None and L is None) if rate is None: rate = self.L / L if L is None: L = int(self.L / rate) # Increase the internal length if needed while rate * L > self.L: self.double_length() dt = torch.exp(self.log_dt) * rate B = _r2c(self.B) C = _r2c(self.C) P = _r2c(self.P) Q = P.conj() if self.Q is None else _r2c(self.Q) w = self._w() if rate == 1.0: # Use cached FFT nodes omega, z = _r2c(self.omega), _r2c(self.z) # (..., L) else: omega, z = self._omega( int(self.L / rate), dtype=w.dtype, device=w.device, cache=False ) if self.tie_state: B = repeat(B, "... 1 n -> ... h n", h=self.H) P = repeat(P, "... 1 n -> ... h n", h=self.H) Q = repeat(Q, "... 1 n -> ... h n", h=self.H) # Augment B if state is not None: # Have to "unbilinear" the state to put it into the same "type" as B # Compute 1/dt * (I + dt/2 A) @ state # Can do this without expanding (maybe minor speedup using conj symmetry in theory), but it's easier to read this way s = _conj(state) if state.size(-1) == self.N else state # (B H N) sA = s * _conj(w) - oe.contract( # (B H N) "bhm, rhm, rhn -> bhn", s, _conj(Q), _conj(P) ) s = s / dt.unsqueeze(-1) + sA / 2 s = s[..., : self.N] B = torch.cat([s, B], dim=-3) # (s+1, H, N) # Incorporate dt into A w = w * dt.unsqueeze(-1) # (H N) # Stack B and p, C and q for convenient batching B = torch.cat([B, P], dim=-3) # (s+1+r, H, N) C = torch.cat([C, Q], dim=-3) # (c+r, H, N) # Incorporate B and C batch dimensions v = B.unsqueeze(-3) * C.unsqueeze(-4) # (s+1+r, c+r, H, N) # w = w[None, None, ...] # (1, 1, H, N) # z = z[None, None, None, ...] # (1, 1, 1, L) # Calculate resolvent at omega if has_cauchy_extension and z.dtype == torch.cfloat: r = cauchy_mult(v, z, w, symmetric=True) else: r = cauchy_conj(v, z, w) r = r * dt[None, None, :, None] # (S+1+R, C+R, H, L) # Low-rank Woodbury correction if self.rank == 1: k_f = r[:-1, :-1, :, :] - r[:-1, -1:, :, :] * r[-1:, :-1, :, :] / ( 1 + r[-1:, -1:, :, :] ) elif self.rank == 2: r00 = r[: -self.rank, : -self.rank, :, :] r01 = r[: -self.rank, -self.rank :, :, :] r10 = r[-self.rank :, : -self.rank, :, :] r11 = r[-self.rank :, -self.rank :, :, :] det = (1 + r11[:1, :1, :, :]) * (1 + r11[1:, 1:, :, :]) - r11[ :1, 1:, :, : ] * r11[1:, :1, :, :] s = ( r01[:, :1, :, :] * (1 + r11[1:, 1:, :, :]) * r10[:1, :, :, :] + r01[:, 1:, :, :] * (1 + r11[:1, :1, :, :]) * r10[1:, :, :, :] - r01[:, :1, :, :] * (r11[:1, 1:, :, :]) * r10[1:, :, :, :] - r01[:, 1:, :, :] * (r11[1:, :1, :, :]) * r10[:1, :, :, :] ) s = s / det k_f = r00 - s else: r00 = r[: -self.rank, : -self.rank, :, :] r01 = r[: -self.rank, -self.rank :, :, :] r10 = r[-self.rank :, : -self.rank, :, :] r11 = r[-self.rank :, -self.rank :, :, :] r11 = rearrange(r11, "a b h n -> h n a b") r11 = torch.linalg.inv(torch.eye(self.rank, device=r.device) + r11) r11 = rearrange(r11, "h n a b -> a b h n") k_f = r00 - torch.einsum( "i j h n, j k h n, k l h n -> i l h n", r01, r11, r10 ) # Final correction for the bilinear transform k_f = k_f * 2 / (1 + omega) # Move from frequency to coefficients k = torch.fft.irfft(k_f) # (S+1, C, H, L) # Truncate to target length k = k[..., :L] if state is not None: k_state = k[:-1, :, :, :] # (S, C, H, L) else: k_state = None k_B = k[-1, :, :, :] # (C H L) return k_B, k_state
@torch.no_grad() def double_length(self): """Double the internal kernel length. Returns ------- None The function updates the internal length and cached FFT nodes. """ self._setup_C(double_length=True) def _setup_linear(self): """Set up parameters for fast linear recurrent stepping. Returns ------- None The function stores stepping parameters in self.step_params. """ w = self._w() B = _r2c(self.B) # (H N) P = _r2c(self.P) Q = P.conj() if self.Q is None else _r2c(self.Q) # Prepare Linear stepping dt = torch.exp(self.log_dt) D = (2.0 / dt.unsqueeze(-1) - w).reciprocal() # (H, N) R = ( torch.eye(self.rank, dtype=w.dtype, device=w.device) + 2 * oe.contract("r h n, h n, s h n -> h r s", Q, D, P).real ) # (H r r) Q_D = rearrange(Q * D, "r h n -> h r n") R = torch.linalg.solve(R.to(Q_D), Q_D) # (H r N) R = rearrange(R, "h r n -> r h n") self.step_params = { "D": D, # (H N) "R": R, # (r H N) "P": P, # (r H N) "Q": Q, # (r H N) "B": B, # (1 H N) "E": 2.0 / dt.unsqueeze(-1) + w, # (H N) } def _step_state_linear(self, u=None, state=None): """ Version of the step function that has time O(N) instead of O(N^2) per step, which takes advantage of the DPLR form and bilinear discretization. Unfortunately, as currently implemented it's about 2x slower because it calls several sequential operations. Perhaps a fused CUDA kernel implementation would be much faster u: (H) input state: (H, N/2) state with conjugate pairs Optionally, the state can have last dimension N Returns: same shape as state """ C = _r2c(self.C) # View used for dtype/device if u is None: # Special case used to find dA u = torch.zeros(self.H, dtype=C.dtype, device=C.device) if state is None: # Special case used to find dB state = torch.zeros(self.H, self.N, dtype=C.dtype, device=C.device) step_params = self.step_params.copy() # Only store half of the conjugate pairs; should be true by default if state.size(-1) == self.N: # There should be a slightly faster way using conjugate symmetry def contract_fn(p, x, y): return oe.contract( "r h n, r h m, ... h m -> ... h n", _conj(p), _conj(x), _conj(y) )[ ..., : self.N ] # inner outer product else: assert state.size(-1) == 2 * self.N step_params = {k: _conj(v) for k, v in step_params.items()} # TODO worth setting up a contract_expression in default_state if we want to use this at inference time for stepping def contract_fn(p, x, y): return oe.contract( "r h n, r h m, ... h m -> ... h n", p, x, y ) # inner outer product D = step_params["D"] # (H N) E = step_params["E"] # (H N) R = step_params["R"] # (r H N) P = step_params["P"] # (r H N) Q = step_params["Q"] # (r H N) B = step_params["B"] # (1 H N) new_state = E * state - contract_fn(P, Q, state) # (B H N) new_state = new_state + 2.0 * B * u.unsqueeze(-1) # (B H N) new_state = D * (new_state - contract_fn(P, R, new_state)) return new_state def _setup_state(self): """Construct dA and dB for discretized state equation""" # Construct dA and dB by using the stepping self._setup_linear() # Just returns a view that we use for finding dtype/device C = _r2c(self.C) state = torch.eye(2 * self.N, dtype=C.dtype, device=C.device).unsqueeze( -2 ) # (N 1 N) dA = self._step_state_linear(state=state) dA = rearrange(dA, "n h m -> h m n") self.dA = dA # (H N N) u = C.new_ones(self.H) dB = self._step_state_linear(u=u) dB = _conj(dB) self.dB = rearrange(dB, "1 h n -> h n") # (H N) def _step_state(self, u, state): """Must be called after self.default_state() is used to construct an initial state! Parameters ---------- u : torch.Tensor Input tensor for the current time step. state : torch.Tensor Current recurrent state. Returns ------- torch.Tensor Updated recurrent state. """ next_state = self.state_contraction(self.dA, state) + self.input_contraction( self.dB, u ) return next_state
[docs] def setup_step(self, mode="dense"): """Set up dA, dB, dC discretized parameters for stepping""" self._setup_state() # Calculate original C dA_L = power(self.L, self.dA) Id = torch.eye(self.dA.size(-1)).to(dA_L) C = _conj(_r2c(self.C)) # (H C N) dC = torch.linalg.solve( Id - dA_L.transpose(-1, -2), C.unsqueeze(-1), ).squeeze(-1) self.dC = dC # Do special preprocessing for different step modes self._step_mode = mode if mode == "linear": # Linear case: special step function for the state, we need to handle output # use conjugate symmetry by default, which affects the output projection self.dC = 2 * self.dC[:, :, : self.N] elif mode == "diagonal": # Eigendecomposition of the A matrix L, V = torch.linalg.eig(self.dA) V_inv = torch.linalg.inv(V) # Check that the eigendedecomposition is correct if self.verbose: print( "Diagonalization error:", torch.dist(V @ torch.diag_embed(L) @ V_inv, self.dA), ) # Change the parameterization to diagonalize self.dA = L self.dB = oe.contract("h n m, h m -> h n", V_inv, self.dB) self.dC = oe.contract("h n m, c h n -> c h m", V, self.dC) elif mode == "dense": pass else: raise NotImplementedError( "NPLR Kernel step mode must be {'dense' | 'linear' | 'diagonal'}" )
[docs] def default_state(self, *batch_shape): """Create an initial recurrent state. Parameters ---------- *batch_shape : int Batch dimensions of the state. Returns ------- torch.Tensor Zero-initialized recurrent state. """ C = _r2c(self.C) N = C.size(-1) H = C.size(-2) # Cache the tensor contractions we will later do, for efficiency # These are put in this function because they depend on the batch size if self._step_mode != "linear": N *= 2 if self._step_mode == "diagonal": self.state_contraction = oe.contract_expression( "h n, ... h n -> ... h n", (H, N), batch_shape + (H, N), ) else: # Dense (quadratic) case: expand all terms self.state_contraction = oe.contract_expression( "h m n, ... h n -> ... h m", (H, N, N), batch_shape + (H, N), ) self.input_contraction = oe.contract_expression( "h n, ... h -> ... h n", (H, N), # self.dB.shape batch_shape + (H,), ) self.output_contraction = oe.contract_expression( "c h n, ... h n -> ... c h", (C.shape[0], H, N), # self.dC.shape batch_shape + (H, N), ) state = torch.zeros(*batch_shape, H, N, dtype=C.dtype, device=C.device) return state
[docs] def step(self, u, state): """Must have called self.setup_step() and created state with self.default_state() before calling this""" if self._step_mode == "linear": new_state = self._step_state_linear(u, state) else: new_state = self._step_state(u, state) y = self.output_contraction(self.dC, new_state) return y, new_state
[docs] def register(self, name, tensor, trainable=False, lr=None, wd=None): """Register a tensor as a parameter or buffer. Parameters ---------- name : str Name used to register the tensor. tensor : torch.Tensor Tensor to register. trainable : bool If True, register the tensor as a trainable parameter. Otherwise, register it as a buffer. lr : float Optional learning rate metadata. wd : float Optional weight decay metadata. Returns ------- None The function registers the tensor in the module. """ if trainable: self.register_parameter(name, nn.Parameter(tensor)) else: self.register_buffer(name, tensor) optim = {} if trainable and lr is not None: optim["lr"] = lr if trainable and wd is not None: optim["weight_decay"] = wd if len(optim) > 0: setattr(getattr(self, name), "_optim", optim)
[docs] class HippoSSKernel(nn.Module): """Wrapper around SSKernel that generates A, B, C, dt according to HiPPO arguments. The SSKernel is expected to support the interface forward() default_state() setup_step() step() """ def __init__( self, H, N=64, L=1, measure="legs", rank=1, channels=1, dt_min=0.001, dt_max=0.1, trainable=None, lr=None, # Multiply by I-A|^L after initialization; can be turned off for initialization speed length_correction=True, hurwitz=False, tie_state=False, precision=1, resample=False, verbose=False, ): """Initialize the HiPPO state-space kernel. Parameters ---------- H : int Number of hidden features. N : int State dimension. L : int Maximum sequence length. measure : str HiPPO measure used to construct the state matrix. rank : int Rank of the low-rank correction. channels : int Number of output channels or heads. dt_min : float Minimum discretization step size. dt_max : float Maximum discretization step size. trainable : bool or dict Controls which kernel parameters are trainable. lr : float Optional learning rate metadata for HiPPO parameters. length_correction : bool If True, apply length correction during kernel initialization. hurwitz : bool If True, enforce negative real parts in the state matrix. tie_state : bool If True, tie state parameters across hidden features. precision : int Numerical precision. Use 1 for single precision and 2 for double precision. resample : bool If True, allow resampling for inputs with different lengths. verbose : bool If True, print diagnostic information. """ super().__init__() self.N = N self.H = H L = L or 1 self.precision = precision dtype = torch.double if self.precision == 2 else torch.float cdtype = torch.cfloat if dtype == torch.float else torch.cdouble self.rate = None if resample else 1.0 self.channels = channels # Generate dt log_dt = torch.rand(self.H, dtype=dtype) * ( math.log(dt_max) - math.log(dt_min) ) + math.log(dt_min) w, p, B, _ = nplr(measure, self.N, rank, dtype=dtype) C = torch.randn(channels, self.H, self.N // 2, dtype=cdtype) self.kernel = SSKernelNPLR( L, w, p, B, C, log_dt, hurwitz=hurwitz, trainable=trainable, lr=lr, tie_state=tie_state, length_correction=length_correction, verbose=verbose, )
[docs] def forward(self, L=None, rate=1.0): k, _ = self.kernel(rate=rate, L=L) return k.float()
[docs] def step(self, u, state, **kwargs): u, state = self.kernel.step(u, state, **kwargs) return u.float(), state
[docs] def default_state(self, *args, **kwargs): return self.kernel.default_state(*args, **kwargs)
[docs] class S4(nn.Module): """Structured State Space Sequence layer.""" def __init__( self, d_model, d_state=64, l_max=1, channels=1, bidirectional=False, # Arguments for FF activation="gelu", postact=None, initializer=None, weight_norm=False, hyper_act=None, dropout=0.0, transposed=True, verbose=False, **kernel_args, ): """Initialize the S4 layer. Parameters ---------- d_model : int Hidden feature dimension. d_state : int State dimension. l_max : int Maximum sequence length. channels : int Number of state-space output channels. bidirectional : bool If True, use a bidirectional convolution kernel. activation : str Activation function applied after the state-space convolution. postact : str Activation function applied after the output linear layer. initializer : str Initializer used for the output linear layer. weight_norm : bool If True, apply weight normalization to the output linear layer. hyper_act : str Optional activation for hypernetwork-style multiplicative modulation. dropout : float Dropout probability. transposed : bool If True, inputs use shape (batch_size, d_model, length). If False, inputs use shape (batch_size, length, d_model). verbose : bool If True, print diagnostic information. **kernel_args Additional keyword arguments passed to HippoSSKernel. """ super().__init__() self.h = d_model self.n = d_state self.bidirectional = bidirectional self.channels = channels self.transposed = transposed # optional multiplicative modulation GLU-style # https://arxiv.org/abs/2002.05202 self.hyper = hyper_act is not None if self.hyper: channels *= 2 self.hyper_activation = Activation(hyper_act) self.D = nn.Parameter(torch.randn(channels, self.h)) if self.bidirectional: channels *= 2 # SSM Kernel self.kernel = HippoSSKernel( self.h, N=self.n, L=l_max, channels=channels, verbose=verbose, **kernel_args ) # Pointwise self.activation = Activation(activation) dropout_fn = nn.Dropout2d if self.transposed else nn.Dropout self.dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity() # position-wise output transform to mix features self.output_linear = LinearActivation( self.h * self.channels, self.h, transposed=self.transposed, initializer=initializer, activation=postact, activate=True, weight_norm=weight_norm, )
[docs] def forward(self, u, rate=1.0, **kwargs): """Run a forward pass through the S4 layer. Parameters ---------- u : torch.Tensor Input tensor. If transposed is True, the shape is (batch_size, d_model, sequence_length). Otherwise, the shape is (batch_size, sequence_length, d_model). rate : float Sampling rate factor. **kwargs Additional keyword arguments kept for compatibility. Returns ------- tuple Tuple containing the output tensor and None for compatibility with recurrent interfaces. """ if not self.transposed: u = u.transpose(-1, -2) L = u.size(-1) # Compute SS Kernel k = self.kernel(L=L, rate=rate) # (C H L) (B C H L) # Convolution if self.bidirectional: k0, k1 = rearrange(k, "(s c) h l -> s c h l", s=2) k = F.pad(k0, (0, L)) + F.pad(k1.flip(-1), (L, 0)) k_f = torch.fft.rfft(k, n=2 * L) # (C H L) u_f = torch.fft.rfft(u, n=2 * L) # (B H L) # k_f.unsqueeze(-4) * u_f.unsqueeze(-3) # (B C H L) y_f = oe.contract("bhl,chl->bchl", u_f, k_f) y = torch.fft.irfft(y_f, n=2 * L)[..., :L] # (B C H L) # Compute D term in state space equation - essentially a skip connection # u.unsqueeze(-3) * self.D.unsqueeze(-1) y = y + oe.contract("bhl,ch->bchl", u, self.D) # Optional hyper-network multiplication if self.hyper: y, yh = rearrange(y, "b (s c) h l -> s b c h l", s=2) y = self.hyper_activation(yh) * y # Reshape to flatten channels y = rearrange(y, "... c h l -> ... (c h) l") y = self.dropout(self.activation(y)) if not self.transposed: y = y.transpose(-1, -2) y = self.output_linear(y) return y, None
[docs] def step(self, u, state): """Step one time step as a recurrent model. Intended to be used during validation. u: (B H) state: (B H N) Returns: output (B H), state (B H N) """ assert not self.training y, next_state = self.kernel.step(u, state) # (B C H) y = y + u.unsqueeze(-2) * self.D y = rearrange(y, "... c h -> ... (c h)") y = self.activation(y) if self.transposed: y = self.output_linear(y.unsqueeze(-1)).squeeze(-1) else: y = self.output_linear(y) return y, next_state
[docs] def default_state(self, *batch_shape, device=None): return self.kernel.default_state(*batch_shape)
@property def d_state(self): return self.h * self.n @property def d_output(self): return self.h @property def state_to_tensor(self): return lambda state: rearrange("... h n -> ... (h n)", state)