# Source code for svise.utils

from abc import ABC, abstractmethod
import math
from typing import Callable, Tuple, Union
from typing import Callable

from gpytorch import kernels
import numpy as np
import scipy.linalg
import torch
from torch import sigmoid
from torch import Tensor
from torch.autograd import backward
import torch.nn as nn
from torch.nn.parameter import Parameter

# from .extern import uwbayes


def finite_difference(func, t, dt):
    """
    Approximates the derivative of func at t with a central difference.

    Args:
        func: callable evaluated at t + dt and at t - dt
        t: point at which the derivative is approximated
        dt: half-width of the difference stencil

    Returns:
        (func(t + dt) - func(t - dt)) / (2 * dt)
    """
    ahead = func(t + dt)
    behind = func(t - dt)
    return (ahead - behind) / (2 * dt)


def skew_sym_duplication_matrix(n: int) -> Tensor:
    """
    Builds the duplication matrix Dn of an n x n skew symmetric matrix,
    ie. Dn @ v(A) = vec(A)

    The free entries are the strictly lower-triangular ones, enumerated
    column by column (j outer, i inner, i > j). Entry k contributes +1 at flat
    position j * n + i and -1 at flat position i * n + j.

    Args:
        n (int): size of the skew symmetric matrix

    Returns:
        Tensor: (n**2, n * (n - 1) // 2) duplication matrix
    """
    num_free = n * (n - 1) // 2
    D = torch.zeros(num_free, n**2)
    lower_pairs = ((i, j) for j in range(n) for i in range(j + 1, n))
    for k, (i, j) in enumerate(lower_pairs):
        D[k, j * n + i] = 1.0
        D[k, i * n + j] = -1.0
    return D.t()


def sample_covariance(x: Tensor) -> Tensor:
    """
    Unbiased (N - 1 normalised) sample covariance over the second-to-last dim.

    Args:
        x (Tensor): (..., bs, d) batch of inputs

    Returns:
        Tensor: (..., d, d) sample covariance
    """
    num_samples = x.shape[-2]
    centered = x - x.mean(-2).unsqueeze(-2)
    weight = 1 / (num_samples - 1)
    # scale the transposed factor first, then contract over the sample dim
    return (weight * centered.transpose(-1, -2)) @ centered


def make_random_matrix(dim: int, rank: int, random_seed: int = None) -> Tensor:
    """
    Draws a random matrix as the average of `rank` outer products u u^T,
    where each u is a standard normal vector of length `dim`.

    Args:
        dim (int): dimension of matrix
        rank (int): number of outer products that are averaged
        random_seed (int, optional): if given, passed to torch.manual_seed
            before sampling (this reseeds the global generator). Defaults to None.

    Returns:
        Tensor: (dim, dim) matrix
    """
    if random_seed is not None:
        torch.manual_seed(random_seed)
    assert dim >= rank, "rank cannot be greater than dim."
    factors = torch.randn(rank, dim, 1)
    outer_products = factors @ factors.transpose(-2, -1)
    return outer_products.sum(0).div(rank)


def bjorck(
    A: Tensor,
    tol: float = 1e-9,
    max_iters: int = 50,
    order: int = 1,
    num_iters=None,
) -> Tuple[Tensor, int]:
    """
    Bjorck orthonormalization of A.

    A is first divided by sqrt(N * M) and then the Bjorck update
    Ak <- Ak (I + Qk / 2 [+ 3 Qk^2 / 8]) with Qk = I - Ak^T Ak is iterated.

    Args:
        A (Tensor): (..., N, M) input tensor
        tol (float, optional): Tolerance on ||I - Ak^T Ak||_F^2 (summed over
            all elements of Qk). Only used when num_iters is None.
            Defaults to 1e-9.
        max_iters (int, optional): Maximum iterations. Defaults to 50.
        order (int, optional): Order of iteration (1 or 2). Defaults to 1.
        num_iters (int, optional): if given, the tolerance check is skipped and
            the loop stops once this many updates have been applied
            (still capped by max_iters). Defaults to None.

    Returns:
        Tuple[Tensor, int]: the iterate Ak and the value of the loop counter
        when the loop ended (0 if max_iters is 0).

    Raises:
        ValueError: if an update is attempted with an unsupported order.
    """

    N, M = A.shape[-2:]
    scale = math.sqrt(N * M)  # safe bjorck scaling
    Ak = A / scale
    # build the identity alongside Ak so non-default devices / dtypes work
    I = torch.eye(M, dtype=Ak.dtype, device=Ak.device)
    j = 0  # keeps the return value defined when max_iters == 0
    for j in range(max_iters):
        Qk = I - Ak.transpose(-2, -1) @ Ak
        if num_iters is None:
            if Qk.pow(2).sum() < tol:
                break  # check sqr frob. norm is less than tol
        elif j == num_iters:
            break
        if order == 1:
            Ak = Ak @ (I + 1 / 2 * Qk)
        elif order == 2:
            Ak = Ak @ (I + 1 / 2 * Qk + 3 / 8 * (Qk @ Qk))
        else:
            raise ValueError(f"Order: {order} not supported.")
    return Ak, j


def solve_least_squares(A: Tensor, B: Tensor, gamma: float = 1e-6) -> Tensor:
    """
    Solves a (damped) least squares problem through the thin SVD of A.

    Each singular value s is inverted as s / (s^2 + gamma), so gamma = 0 gives
    the plain pseudo-inverse solution.

    Args:
        A (Tensor): (n, m) design matrix
        B (Tensor): (n, d) right hand side
        gamma (float, optional): damping added to the squared singular values.
            Defaults to 1e-6.

    Returns:
        Tensor: W where A @ W = B (approximately, when gamma > 0)
    """
    assert A.dim() == 2, "A must be a matrix"
    assert B.dim() == 2, "B must be a matrix"
    U, S, Vh = torch.linalg.svd(A, full_matrices=False)
    damped_inverse = S / (S.pow(2) + gamma)
    projected = (U.t() @ B).mul(damped_inverse.unsqueeze(-1))
    return Vh.t() @ projected


class AffineTransform(ABC, nn.Module):
    """
    Abstract interface for an invertible affine map x_tilde = A x + b.

    All four methods below are declared abstract, so a concrete subclass has
    to define `forward`, `transform`, `inverse` and `scale`.
    """

    def __init__(self) -> None:
        super().__init__()

    @abstractmethod
    def forward(self, x: Tensor) -> Tensor:
        """
        Applies the map, x_tilde(t) = A x(t) + b

        Args:
            x (Tensor): (..., d) batch of inputs

        Returns:
            Tensor: (..., d) batch of outputs
        """

    @abstractmethod
    def transform(self, x: Tensor) -> Tensor:
        """
        Applies the map, x_tilde(t) = A x(t) + b. The base implementation
        simply delegates to `forward`.

        Args:
            x (Tensor): (..., d) batch of inputs

        Returns:
            Tensor: (..., d) batch of outputs
        """
        return self.forward(x)

    @abstractmethod
    def inverse(self, x: Tensor) -> Tensor:
        """
        Undoes the map, x(t) = A^-1 (x_tilde(t) - b)

        Args:
            x (Tensor): (..., d) batch of inputs

        Returns:
            Tensor: (..., d) batch of outputs
        """

    @abstractmethod
    def scale(self, dx: Tensor) -> Tensor:
        """
        Applies only the linear part, x_tilde(t) = A x(t). Useful for
        computing the derivative of the transform.

        Args:
            dx (Tensor): (..., d) batch of inputs

        Returns:
            Tensor: (..., d) batch of outputs
        """


class StandardizeTransform(AffineTransform):
    """
    Standardizes data with the mean and standard deviation of a reference
    batch, both computed over dim 0 and stored as buffers `mu` and `sigma`.
    """

    def __init__(self, x: Tensor):
        super().__init__()
        reference_mean = x.mean(0)
        reference_std = x.std(0)
        self.register_buffer("mu", reference_mean)
        self.register_buffer("sigma", reference_std)

    def forward(self, x: Tensor) -> Tensor:
        """Returns (x - mu) / sigma."""
        centered = x - self.mu
        return centered / self.sigma

    def transform(self, x: Tensor) -> Tensor:
        """Alias of `forward`."""
        return self.forward(x)

    def inverse(self, x: Tensor) -> Tensor:
        """Returns x * sigma + mu."""
        rescaled = x * self.sigma
        return rescaled + self.mu

    def scale(self, dx: Tensor) -> Tensor:
        """Returns dx / sigma, i.e. only the linear part of `forward`."""
        return dx / self.sigma


class ScaleTransform(nn.Module, ABC):
    """
    Abstract interface for a transformation that scales data by a factor over
    each dimension. Subclasses provide `scale`, `forward` and `inverse`.
    """

    def __init__(self) -> None:
        super().__init__()

    @property
    @abstractmethod
    def scale(self) -> Tensor:
        """
        Scaling factor

        Returns:
            Tensor: (d,) scaling factor
        """

    @abstractmethod
    def forward(self, x: Tensor) -> Tensor:
        """
        Forward direction of the transform, in general x / self.scale

        Args:
            x (Tensor): input

        Returns:
            Tensor: x scaled by 1 / self.scale
        """

    @abstractmethod
    def inverse(self, x: Tensor) -> Tensor:
        """
        Inverse direction of the transform, in general x * self.scale

        Args:
            x (Tensor): input

        Returns:
            Tensor: x scaled by self.scale
        """


class StdevScaleTransform(ScaleTransform):
    """
    Scale transform whose factor is the standard deviation (over dim 0) of the
    reference batch passed to the constructor, stored as buffer `sigma`.
    """

    def __init__(self, x: Tensor):
        super().__init__()
        reference_std = x.std(0)
        self.register_buffer("sigma", reference_std)

    @property
    def scale(self) -> Tensor:
        """Scaling factor (the stored standard deviation)."""
        return self.sigma

    def forward(self, x: Tensor) -> Tensor:
        """Returns x / scale."""
        factor = self.scale
        return x / factor

    def inverse(self, x: Tensor) -> Tensor:
        """Returns x * scale."""
        factor = self.scale
        return x * factor


class IdentityScaleTransform(ScaleTransform):
    """
    Scale transform with a unit factor: `forward` and `inverse` return their
    input unchanged and `scale` is a buffer of d ones.
    """

    def __init__(self, d: int) -> None:
        super().__init__()
        unit_scale = torch.ones(d)
        self.register_buffer("ones", unit_scale)

    @property
    def scale(self) -> Tensor:
        """Scaling factor, a (d,) tensor of ones."""
        return self.ones

    def forward(self, x: Tensor) -> Tensor:
        """Returns x as is."""
        return x

    def inverse(self, x: Tensor) -> Tensor:
        """Returns x as is."""
        return x


class MultivariateStandardizeTransform(nn.Module):
    """
    Whitens data with the Cholesky factor of a covariance matrix.

    Stores the mean of the reference batch (over dim 0) as buffer `mu` and the
    lower Cholesky factor of R as buffer `L`; `forward` maps x to
    L^-1 (x - mu) and `inverse` maps back.

    Args:
        x (Tensor): reference batch; `forward` transposes with `.t()`, so a
            (bs, d) layout is assumed
        R (Tensor, optional): (d, d) covariance matrix. When None the sample
            covariance of x is used. Defaults to None.
    """

    def __init__(self, x: Tensor, R: Tensor = None):
        super().__init__()
        self.register_buffer("mu", x.mean(0))
        if R is None:
            R = sample_covariance(x)
        self.register_buffer("L", torch.linalg.cholesky(R))

    def forward(self, x: Tensor) -> Tensor:
        """Solves L z^T = (x - mu)^T for z, i.e. z = (x - mu) L^-T."""
        # torch.triangular_solve is deprecated; torch.linalg.solve_triangular
        # takes the triangular matrix first and returns the solution directly.
        centered = (x - self.mu).t()
        return torch.linalg.solve_triangular(self.L, centered, upper=False).t()

    def inverse(self, x: Tensor) -> Tensor:
        """Returns x L^T + mu."""
        return x @ self.L.t() + self.mu


class Identity(nn.Module):
    """Module whose `forward` and `inverse` both return their input unchanged."""

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        """Returns x as is."""
        return x

    def inverse(self, x):
        """Returns x as is."""
        return x


class Positive(nn.Module):
    """
    Maps inputs to strictly positive values with a softplus plus a small
    offset (1e-6).

    Args:
        beta (float, optional): beta of the softplus. Defaults to 1.0.
    """

    def __init__(self, beta=1.0):
        super().__init__()
        self.output_transform = nn.Softplus(beta=beta)
        self.minimum = 1e-6

    def forward(self, x, return_grad=False):
        """
        Args:
            x: input tensor
            return_grad (bool): when True also return sigmoid(beta * x)

        Returns:
            softplus(x) + minimum, or the tuple
            (softplus(x) + minimum, sigmoid(beta * x)) when return_grad is True
        """
        res = self.output_transform(x) + self.minimum
        if return_grad:
            return res, sigmoid(x * self.output_transform.beta)
        return res


class ReHU(nn.Module):
    """Rectified Huber unit
    from: https://github.com/locuslab/stable_dynamics/blob/master/models/stabledynamics.py

    Element-wise maximum of a quadratic piece clamped to [0, d / 2]
    and the shifted linear piece x - d / 2.
    """

    def __init__(self, d: float = 1.0):
        super().__init__()
        self.a = 1 / d
        self.b = -d / 2

    def forward(self, x):
        quadratic = torch.sign(x) * self.a / 2 * x**2
        capped = torch.clamp(quadratic, min=0, max=-self.b)
        linear = x + self.b
        return torch.max(capped, linear)


class GLM(nn.Module):
    """
    Kernel-feature linear model: output_transform(K(t, tau) @ w + b).

    Args:
        output_size (int): output dimension
        a (float): first inducing location
        b (float): last inducing location
        n_tau (int): number of inducing locations (torch.linspace(a, b, n_tau))
        kernel (str, optional): one of "matern52", "matern32", "rbf".
            Defaults to "matern52".
        learn_inducing_locations (bool, optional): store tau as a Parameter
            instead of a buffer. Defaults to False.
        output_transform (Callable, optional): applied to the model output.
            Defaults to Identity().

    Raises:
        ValueError: if kernel is not one of the supported names.
    """

    def __init__(
        self,
        output_size: int,
        a: float,
        b: float,
        n_tau: int,
        kernel: str = "matern52",
        learn_inducing_locations: bool = False,
        output_transform: Callable = Identity(),
    ) -> None:
        super().__init__()
        # NOTE: this first draw is overwritten further down; it is kept because
        # it advances the global random generator before the second draw.
        self.w = Parameter(torch.randn(n_tau, output_size))
        tau = torch.linspace(a, b, n_tau)
        if learn_inducing_locations:
            self.tau = Parameter(tau)
        else:
            self.register_buffer("tau", tau)
        kernel_factories = {
            "matern52": lambda: kernels.MaternKernel(nu=2.5),
            "matern32": lambda: kernels.MaternKernel(nu=1.5),
            "rbf": lambda: kernels.RBFKernel(),
        }
        if kernel not in kernel_factories:
            raise ValueError(f"Unknown kernel {kernel}")
        self.K = kernel_factories[kernel]()
        # weights are drawn as chol(K(tau, tau)^-1 + 1e-4 I) @ standard normal
        gram_inverse = self.K(tau, tau).evaluate().inverse()
        init_factor = torch.linalg.cholesky(gram_inverse + 1e-4 * torch.eye(len(tau)))
        self.w = Parameter(init_factor @ torch.randn(n_tau, output_size))
        self.b = Parameter(torch.zeros(output_size))
        self.output_transform = output_transform

    def forward(self, t: Tensor) -> Tensor:
        """Evaluates the model at t; a 0-dim t is promoted to shape (1,)."""
        if t.dim() == 0:
            t = t.unsqueeze(-1)
        features = self.K(t, self.tau).evaluate()
        return self.output_transform(features @ self.w + self.b)


class FCNN(nn.Module):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        num_layers: int,
        nonlinearity: Callable = nn.Softplus(),
        output_transform: Callable = Identity(),
        batch_norm: bool = True,
    ) -> None:
        """
        Fully connected neural network with optional batchnorm

        The network is one input block followed by num_layers - 1 further
        hidden blocks (each Linear -> nonlinearity [-> BatchNorm1d]) and a
        final Linear -> output_transform. The same `nonlinearity` object is
        reused in every block.

        Args:
            input_size (int): input dimension
            hidden_size (int): number of hidden units in each hidden layer
            output_size (int): output dimension
            num_layers (int): number of hidden layers
            nonlinearity (Callable, optional): nonlinearity at each layer. Defaults to nn.Softplus().
            output_transform (Callable, optional): last stage of the network. Defaults to Identity().
            batch_norm (bool, optional): add BatchNorm1d after each nonlinearity. Defaults to True.
        """
        super().__init__()

        def hidden_block(in_features: int) -> list:
            block = [nn.Linear(in_features, hidden_size), nonlinearity]
            if batch_norm:
                block.append(nn.BatchNorm1d(hidden_size))
            return block

        layers = hidden_block(input_size)
        for _ in range(num_layers - 1):
            layers.extend(hidden_block(hidden_size))
        layers.extend([nn.Linear(hidden_size, output_size), output_transform])
        self.mlp = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """
        Runs the network. 0-dim inputs are reshaped to (1, 1), 1-dim inputs get
        a trailing feature dim, 2-dim inputs pass through, and higher-rank
        inputs are flattened to (-1, d) and reshaped back afterwards.
        """
        rank = x.dim()
        if rank == 0:
            return self.mlp(x.reshape(1, 1))
        if rank == 1:
            return self.mlp(x.unsqueeze(-1))
        if rank == 2:
            return self.mlp(x)
        *batch_shape, num_features = x.shape
        flat_out = self.mlp(x.reshape(-1, num_features))
        return flat_out.reshape(*batch_shape, -1)


def multioutput_gradient(f: Tensor, t: Tensor, vmap=True) -> Tensor:
    """
    Computes the gradient of f at a batch of times t.

    One vector-Jacobian product of f.sum(0) with respect to t is taken per
    output dimension, using the rows of the identity as the vectors.

    Args:
        f (Tensor): (bs,d) tensor of function evaluations
        t (Tensor): (bs,) batch of times
        vmap (bool, optional): Bool indicating whether to use vmap. Defaults to True.

    Returns:
        Tensor: (bs,d) gradient of f at t
    """
    num_outputs = f.shape[-1]
    basis_vectors = torch.eye(num_outputs)

    def get_vjp(v):
        return torch.autograd.grad(f.sum(0), t, v, retain_graph=True)[0]

    if vmap:
        # experimental path: batches the vjp over the basis vectors
        torch._C._debug_only_display_vmap_fallback_warnings(True)
        return torch.vmap(get_vjp)(basis_vectors).t()

    # looped solution
    per_output = [get_vjp(v) for v in basis_vectors.unbind()]
    return torch.stack(per_output, dim=-1)


class EulerMaruyama(object):
    """
    Performs euler-maruyama integration of an SDE

    After construction the instance exposes `x` (the stacked sample paths,
    time along dim 1), `t` (the time grid) and, when more than one sample is
    integrated, `mean` and `var` of the paths over the sample dimension.

    Args:
        f (Callable): (float, (bs, d)) -> (bs,d) drift function computes a batch of drift values
        L (Union[Callable, Tensor]): Union[(float, (bs, d)) -> (bs, d, S), (d,S)] diffusion function computes a batch of diffusion values
        Q (Tensor): (S,S) diffusion matrix of Brownian motion
        t_span (Tuple): integration time window
        x0 (Tensor): (bs,d) initial batch of samples from initial condition
        dt (float): integration time step
    """

    def __init__(
        self,
        f: Callable,
        L: Union[Callable, Tensor],
        Q: Tensor,
        t_span: Tuple,
        x0: Tensor,
        dt: float,
    ) -> None:
        super().__init__()
        bs = x0.shape[0]
        # Cholesky factor of the increment covariance Q * dt
        Q_chol = torch.linalg.cholesky(Q) * math.sqrt(dt)
        t = torch.arange(t_span[0], t_span[1] + dt, dt)
        xk = x0.clone()
        x_list = [xk]
        for tk in t[:-1]:
            # draw the Brownian increment with Q's dtype / device so the
            # matmul below does not fail for non-default tensors
            dbeta = (
                torch.randn(
                    bs, Q_chol.shape[0], dtype=Q_chol.dtype, device=Q_chol.device
                )
                @ Q_chol.t()
            )
            drift_step = f(tk, xk) * dt
            # the diffusion is evaluated at the current state x_k (not at the
            # drift-updated state), as the Euler-Maruyama scheme prescribes
            if callable(L):
                noise_samples = torch.einsum("ijk,ik->ij", L(tk, xk), dbeta)
            else:
                noise_samples = dbeta @ L.t()
            xk = xk + drift_step + noise_samples
            x_list.append(xk)
        self.x = torch.stack(x_list, dim=1)
        if bs > 1:
            self.mean = self.x.mean(0)
            self.var = sample_covariance(self.x.transpose(1, 0))
        self.t = t
def solve_lyapunov_spectral(D: Tensor, Q: Tensor, RHS: Tensor) -> Tensor:
    """
    Solves the lyapunov equation, XK^T + KX^T = RHS, using the spectral
    decomposition of K

    RHS is rotated into the eigenbasis, divided element-wise by the pairwise
    sums D_i + D_j, and rotated back.

    Args:
        D (Tensor): (..., n) array of eigen values
        Q (Tensor): (n,n) eigenvectors
        RHS (Tensor): (n,n) right hand side

    Returns:
        Tensor: (n,n) solution (all size-one dims are squeezed when D is 1-D)
    """
    n = D.shape[-1]
    squeeze = not D.dim() > 1
    bs = (1,) if squeeze else D.shape[:-1]
    bs_rhs = RHS.shape[:-2]
    Q_t = Q.transpose(-2, -1)
    rotated_rhs = Q_t @ RHS @ Q
    # generate all eigenvalues of K \oplus K
    eigs = (D.unsqueeze(-1) + D.unsqueeze(-2)).reshape(*bs, -1)
    vecX = rotated_rhs.transpose(-2, -1).reshape(*bs_rhs, -1) / eigs
    X = Q @ vecX.reshape(*bs_rhs, n, n).transpose(-2, -1) @ Q_t
    return X.squeeze() if squeeze else X
def solve_lyapunov_diag(K_diag: Tensor, RHS_diag: Tensor) -> Tensor:
    """
    Solves the lyapunov equation XK^T + KX^T = RHS in the case that K and RHS
    are diagonal (order(N) operations)

    Args:
        K_diag (Tensor): (n,) diagonal of K
        RHS_diag (Tensor): (n,) diagonal of RHS

    Returns:
        Tensor: element-wise 0.5 * RHS_diag / K_diag, i.e. the diagonal of the
        solution
    """
    # A 3-dim RHS_diag is the batch-wise case for SVI and is used as is.
    # Otherwise, when the leading sizes differ, a dim is inserted at position 1
    # so the division broadcasts (todo: what case was this for?).
    needs_extra_dim = RHS_diag.dim() != 3 and K_diag.shape[0] != RHS_diag.shape[0]
    if needs_extra_dim:
        RHS_diag = RHS_diag.unsqueeze(1)
    return 0.5 * RHS_diag / K_diag
def mean_diagonal_gaussian_loglikelihood(
    x: Tensor, mu: Tensor, var: Tensor, log2pi: Tensor = None
) -> Tensor:
    """
    Computes the mean log likelihood for a batch of data points given a batch
    of means and variances

    Args:
        x (Tensor): (bs, d) batch of data points
        mu (Tensor): (ns, bs, d) batch of means
        var (Tensor): (d,) variances
        log2pi (Tensor, optional): precomputed log(2 * pi); computed with the
            dtype and device of x when None. Defaults to None.

    Returns:
        Tensor: mean gaussian log likelihood
    """
    k = x.shape[-1]
    if log2pi is None:
        log2pi = torch.log(torch.tensor(2 * math.pi, dtype=x.dtype, device=x.device))
    diff = ((x - mu).pow(2) / var).sum(-1)
    # sum of logs instead of log of the product: same value, but the product of
    # many small variances cannot underflow to zero
    return -0.5 * diff.mean() - k / 2 * log2pi - 0.5 * torch.log(var).sum()


def grad_skew_expm(V, exp_D, V_trns_conj, D, triu_indices, d_indices, dupl_mat):
    """
    Jacobian of the matrix exponential of a skew symmetric matrix with respect
    to its free (vectorised) entries, from the eigen-decomposition.

    See https://www.janmagnus.nl/wips/expo-23.pdf

    Args:
        V: (d, d) or (bs, d, d) eigenvectors
        exp_D: exponential of the eigenvalues
        V_trns_conj: conjugate transpose of V
        D: eigenvalues
        triu_indices: (2, d(d-1)/2) strictly upper-triangular indices
        d_indices: (d,) diagonal indices
        dupl_mat: duplication matrix of the skew symmetric matrix

    Returns:
        Real part of the Jacobian, with vec(exp(S)) along the rows and the
        free entries along the columns.

    Raises:
        ValueError: if V is not 2 or 3 dimensional.
    """
    d = len(d_indices)
    rows, cols = triu_indices[0], triu_indices[1]
    Delta = torch.zeros_like(V)
    eig_gap = D[..., rows] - D[..., cols]
    exp_gap = exp_D[..., rows] - exp_D[..., cols]
    # For exactly repeated eigenvalues (e.g. S = 0) the divided difference is
    # 0 / 0; its limit is the derivative of exp, i.e. exp(D) itself.
    repeated = eig_gap == 0
    safe_gap = torch.where(repeated, torch.ones_like(eig_gap), eig_gap)
    diff = torch.where(repeated, exp_D[..., rows], exp_gap / safe_gap)
    Delta[..., rows, cols] = diff
    Delta[..., cols, rows] = diff
    Delta[..., d_indices, d_indices] = exp_D
    # keep the duplication matrix in the same (complex) dtype as V
    dupl_mat = dupl_mat.to(V.dtype)
    if V.dim() == 2:
        R = torch.kron(V.conj(), V)
        R_inv_dupl = torch.kron(V.transpose(-2, -1), V_trns_conj) @ dupl_mat
    elif V.dim() == 3:
        # batch wise kronecker product
        R = torch.einsum("ijk,ilm->ijlkm", V.conj(), V).view(-1, d**2, d**2)
        R_inv_dupl = (
            torch.einsum("ijk,ilm->ijlkm", V.transpose(-2, -1), V_trns_conj).view(
                -1, d**2, d**2
            )
            @ dupl_mat
        )
    else:
        raise ValueError("S must be 2 or 3 dimensional")
    grad_exp_S = (
        R * Delta.flatten(start_dim=-2, end_dim=-1).unsqueeze(-2) @ (R_inv_dupl)
    )
    return grad_exp_S.real


class FastSkewSymMatrixExp(torch.autograd.Function):
    """
    Matrix exponential of a skew symmetric matrix with a hand-written backward
    pass based on the eigen-decomposition of the Hermitian matrix j * S.
    """

    @staticmethod
    def forward(ctx, S_vec, triu_indices, d_indicies, dupl_mat):
        d = len(d_indicies)
        # allocate S next to S_vec so non-default devices / dtypes work
        S = torch.zeros(
            (*S_vec.shape[:-1], d, d), dtype=S_vec.dtype, device=S_vec.device
        )
        S[..., triu_indices[0], triu_indices[1]] = -S_vec
        S[..., triu_indices[1], triu_indices[0]] = S_vec
        D, V = torch.linalg.eigh(1.0j * S)  # j * S is Hermitian
        D = D / 1.0j
        exp_D = torch.exp(D)
        V_trns_conj = V.transpose(-2, -1).conj()
        ctx.save_for_backward(
            V, exp_D, V_trns_conj, D, triu_indices, d_indicies, dupl_mat
        )
        return (V * exp_D.unsqueeze(-2) @ V_trns_conj).real

    @staticmethod
    def backward(ctx, gradoutputs):
        bs = gradoutputs.shape[:-2]
        grad = (
            gradoutputs.transpose(-2, -1).reshape(*bs, 1, -1)
            @ grad_skew_expm(*ctx.saved_tensors)
        ).squeeze(-2)
        # one gradient per forward input; only S_vec is differentiable
        return (grad, None, None, None)


fast_skew_sym_matrix_exp = FastSkewSymMatrixExp.apply


class SkewSymMatrixExp(nn.Module):
    """
    This module computes the matrix exponential of a skew symmetric matrix
    along with its gradient

    In testing this is about 10x-100x faster than autograd
    """

    def __init__(self, d: int):
        super().__init__()
        self.d = d
        self.register_buffer("dupl_mat", skew_sym_duplication_matrix(d) + 0j)
        self.register_buffer("ind", torch.triu_indices(d, d, offset=+1))
        self.register_buffer("d_ind", torch.arange(d))

    def forward(self, S_vec: Tensor, return_grad: bool = False):
        """Computes matrix exp given n(n-1)/2 input vector

        Args:
            S_vec (Tensor): batch of flattened n(n-1)/2 Tensor defining skew sym matrices
            return_grad (bool): flag indicating whether to return output or output and grad.
                The gradient is rebuilt from the tensors saved on
                `exp_S.grad_fn`, so S_vec presumably has to require grad for
                this path - TODO confirm.

        Returns:
            Union[Tensor, Tuple[Tensor, Tensor]]: exp(S_vec) or (exp(S_vec), grad(exp(S_vec)))
        """
        exp_S = fast_skew_sym_matrix_exp(S_vec, self.ind, self.d_ind, self.dupl_mat)
        if not return_grad:
            return exp_S
        grad_exp_S = grad_skew_expm(*exp_S.grad_fn.saved_tensors)
        return exp_S, grad_exp_S
# ------------------------------------------------------------------------------
# logm from https://github.com/pytorch/pytorch/issues/9983
# ------------------------------------------------------------------------------
def adjoint(A, E, f):
    """
    Evaluates f on the block matrix [[A^H, E], [0, A^H]] and returns its
    upper-right n x n block, cast back to the dtype of A.
    """
    A_H = A.T.conj().to(E.dtype)
    n = A.size(0)
    block = torch.zeros(2 * n, 2 * n, dtype=E.dtype, device=E.device)
    block[:n, :n] = A_H
    block[n:, n:] = A_H
    block[:n, n:] = E
    upper_right = f(block)[:n, n:]
    return upper_right.to(A.dtype)


def logm_scipy(A):
    """Matrix logarithm through scipy.linalg.logm (computed on the CPU), moved back to A's device."""
    result, _error_estimate = scipy.linalg.logm(A.cpu(), disp=False)
    return torch.from_numpy(result).to(A.device)


class Logm(torch.autograd.Function):
    """Matrix logarithm of a single square matrix with a custom backward pass."""

    @staticmethod
    def forward(ctx, A):
        assert A.ndim == 2 and A.size(0) == A.size(1)  # Square matrix
        supported_dtypes = (
            torch.float32,
            torch.float64,
            torch.complex64,
            torch.complex128,
        )
        assert A.dtype in supported_dtypes
        ctx.save_for_backward(A)
        return logm_scipy(A)

    @staticmethod
    def backward(ctx, G):
        (A,) = ctx.saved_tensors
        return adjoint(A, G, logm_scipy)


logm = Logm.apply


def matrix_log(A: Tensor) -> Tensor:
    """
    Matrix logarithm of A. A 3-dim input is treated as a batch and processed
    one matrix at a time; anything else is passed to `logm` directly.
    """
    if A.dim() != 3:
        return logm(A)
    return torch.stack([logm(Ai) for Ai in A], dim=0)