# Source code for svise.sde_learning._marginal_sde

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from torch.utils.data import DataLoader, TensorDataset
from torch import Tensor

# from functorch import vmap, jacfwd
import math
from typing import Union, Tuple, Callable, Iterator, List, Optional
from abc import ABC, abstractmethod
from ..utils import *
from ..kernels import (
    Matern52,
    Matern12,
    Matern52withGradients,
    KumaraswamyWarping,
    IdentityWarp,
)
from ._diffusion_prior import *
import warnings
import numpy as np

__all__ = [
    "MarginalModel",
    "MarginalFCNN",
    "MeanFCNN",
    "StrictlyPositiveFCNN",
    "OrthogonalFCNN",
    "MarginalGLM",
    "MeanGLM",
    "StrictlyPositiveGLM",
    "OrthogonalGLM",
    "MarginalSDE",
    "DiagonalMarginalSDE",
    "SpectralMarginalSDE",
]


class MarginalModel(ABC, nn.Module):
    """
    Abstract base class for a function of time that returns one of the
    marginal statistics of a Markov GP and, on request, its derivative
    with respect to time.
    """

    def __init__(self) -> None:
        super().__init__()

    @abstractmethod
    def forward(
        self, t: Tensor, return_grad: bool = False
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """Evaluate the model (and optionally its time derivative) at ``t``.

        Args:
            t (Tensor): batch of time stamps
            return_grad (bool, optional): if False (default) only the output
                is returned; if True the output and its gradient with
                respect to ``t`` are returned.

        Returns:
            Union[Tensor, Tuple[Tensor, Tensor]]: the output, or the tuple
            ``(output, gradient)`` when ``return_grad`` is True.
        """
        pass


class MarginalFCNN(MarginalModel):
    """
    DEPRECATED

    Fully connected network that maps a rescaled time stamp to
    ``num_outputs`` values and can also return the derivative of those
    values with respect to the (unscaled) time stamp.
    """

    def __init__(
        self,
        num_outputs,
        t_span,
        num_hidden,
        num_layers,
        output_transform: Callable = Identity(),
    ):
        super().__init__()
        self.fcnn = FCNN(
            1,
            num_hidden,
            num_outputs,
            num_layers,
            nonlinearity=nn.SiLU(),
            output_transform=output_transform,
            batch_norm=False,
        )
        # Width of the time window, kept as a non-trainable buffer so that it
        # follows the module across devices.
        span = t_span[1] - t_span[0]
        if isinstance(t_span[0], Tensor):
            dt = span.detach()
        else:
            dt = torch.tensor(span)
        self.register_buffer("dt", dt)

    def forward(
        self, t: Tensor, return_grad=False
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """
        Evaluate the network at ``t``.

        Args:
            t (Tensor): time stamps (a 0-dim tensor is promoted to 1-dim)
            return_grad (bool, optional): also return d(output)/dt

        Returns:
            The network output, or ``(output, d(output)/dt)``.
        """
        # normalize inputs (assumes t is uniformly sampled)
        t_scaled = t.mul(2.0 / self.dt).add(-1.0).mul(math.sqrt(3))
        if t_scaled.dim() == 0:
            t_scaled = t_scaled.unsqueeze(-1)

        if not return_grad:
            return self.fcnn(t_scaled)

        # Gradient tracking is switched on only for the duration of this
        # call when the rescaled input is not already tracked.
        enabled_here = not t_scaled.requires_grad
        if enabled_here:
            t_scaled.requires_grad_(True)

        m = self.fcnn(t_scaled)
        # chain rule: undo the input rescaling applied above
        chain_factor = 2.0 / self.dt * math.sqrt(3)
        dmdt = multioutput_gradient(m, t_scaled, vmap=False).squeeze(-2).mul(chain_factor)

        if enabled_here:
            t_scaled.requires_grad_(False)
        return (m, dmdt)


class MeanFCNN(MarginalFCNN):
    """
    Fully connected model of the marginal mean.

    When both ``train_t`` and ``train_x`` are given, the network is
    pre-trained by least squares on those pairs during construction.
    ``forward`` is inherited unchanged from :class:`MarginalFCNN`.

    Args:
        num_outputs: number of outputs of the network
        t_span: (t0, tf) time span of the data
        num_hidden: number of hidden units passed to the network
        num_layers: number of layers passed to the network
        train_t (Tensor, optional): time stamps used for pre-training
        train_x (Tensor, optional): targets observed at ``train_t``

    Raises:
        ValueError: if only one of ``train_t`` / ``train_x`` is provided.
    """

    def __init__(
        self,
        num_outputs,
        t_span,
        num_hidden,
        num_layers,
        train_t: Optional[Tensor] = None,
        train_x: Optional[Tensor] = None,
    ):
        # Validate with a real exception (asserts vanish under ``python -O``)
        # and before any network is allocated.
        if (train_t is None) != (train_x is None):
            raise ValueError("train_t and train_x must both be provided.")
        super().__init__(num_outputs, t_span, num_hidden, num_layers)
        if train_t is not None:
            self.train_fcnn(train_t, train_x)

    def train_fcnn(
        self,
        train_t,
        train_x,
        num_epochs: int = 2000,
        batch_size: int = 128,
        lr: float = 1e-3,
    ):
        """
        Fit the network to ``(train_t, train_x)`` with Adam on a squared loss.

        Args:
            train_t (Tensor): time stamps
            train_x (Tensor): targets, one row per time stamp
            num_epochs (int, optional): passes over the data. Defaults to 2000.
            batch_size (int, optional): mini-batch size. Defaults to 128.
            lr (float, optional): Adam learning rate. Defaults to 1e-3.
        """
        train_loader = DataLoader(
            TensorDataset(train_t, train_x), batch_size=batch_size, shuffle=True
        )
        optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        for _ in range(num_epochs):
            for t_batch, x_batch in train_loader:
                optimizer.zero_grad()
                x_pred = self.forward(t_batch)
                # squared error summed over outputs, averaged over the batch
                loss = (x_pred - x_batch).pow(2).sum(-1).mean()
                loss.backward()
                optimizer.step()


class StrictlyPositiveFCNN(MarginalFCNN):
    """
    :class:`MarginalFCNN` whose outputs are passed through ``Positive()``.
    """

    def __init__(self, num_outputs, t_span, num_hidden, num_layers):
        super().__init__(
            num_outputs,
            t_span,
            num_hidden,
            num_layers,
            output_transform=Positive(),
        )
        # presumably the last linear layer before the output transform --
        # TODO confirm against the FCNN implementation
        final_layer = self.fcnn.mlp[-2]
        final_layer.bias.data.fill_(1.0)

    def forward(
        self, t: Tensor, return_grad=False
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """Evaluate the parent network; the flag is forwarded as a bool."""
        return super().forward(t, return_grad=bool(return_grad))


class OrthogonalFCNN(MarginalFCNN):
    """
    Network whose ``num_outputs * (num_outputs - 1) / 2`` raw outputs are
    mapped to a ``num_outputs x num_outputs`` matrix by ``SkewSymMatrixExp``.

    Args:
        num_outputs: size of the square matrix that is returned
        t_span: (t0, tf) time span of the data
        num_hidden: number of hidden units passed to the network
        num_layers: number of layers passed to the network
    """

    def __init__(self, num_outputs, t_span, num_hidden, num_layers):
        # one network output per strictly-lower-triangular entry
        num_skew = num_outputs * (num_outputs - 1) // 2
        super().__init__(num_skew, t_span, num_hidden, num_layers)
        self.orthogonal_mexp = SkewSymMatrixExp(num_outputs)
        self.num_outputs = num_outputs

    def forward(
        self, t: Tensor, return_grad=False
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """
        Evaluate the matrix-valued model at ``t``.

        Args:
            t (Tensor): time stamps; a 0-dim tensor is treated as a batch
                of one, as in the parent class.
            return_grad (bool, optional): also return the time derivative.

        Returns:
            The matrix exponential output, or the tuple of that output and
            its time derivative reshaped to ``(batch, n, n)``.
        """
        if not return_grad:
            m = super().forward(t, return_grad=False)
            return self.orthogonal_mexp(m, return_grad=False)

        m, dmdt = super().forward(t, return_grad=True)
        expS, grad_expS = self.orthogonal_mexp(m, return_grad=True)
        # chain rule through the matrix exponential
        dexpSdt = (grad_expS @ dmdt.unsqueeze(-1)).squeeze(-1)
        # ``len`` is undefined for 0-dim tensors, which the parent forward
        # promotes to a batch of one.
        batch_size = 1 if t.dim() == 0 else len(t)
        n = self.num_outputs
        return (expS, dexpSdt.view(batch_size, n, n).transpose(-2, -1))


class MarginalGLM(MarginalModel):
    """
    Generalized linear model over kernel features of time.

    The output at time ``t`` is ``K(t, tau) @ w + b`` where ``tau`` are
    ``n_tau`` inducing locations spread uniformly over the time span padded
    by one unit on each side, ``K`` is the selected kernel, and ``w``, ``b``
    are learnable.

    Args:
        num_outputs: number of outputs
        t_span: (t0, tf) time span of the data
        n_tau: number of inducing locations
        learn_inducing_locations (bool, optional): make ``tau`` a parameter
        whitened_param (bool, optional): store ``raw_w = C^T w`` instead of
            ``w``, with ``C`` the Cholesky factor of ``K(tau, tau)``
        kernel (str, optional): "matern52", "matern52withgradients" or
            "matern12"
        apply_input_warping (bool, optional): use ``KumaraswamyWarping`` on
            the inputs instead of ``IdentityWarp``
        len_init (float, optional): initial kernel length scale
        dynamic_update_bounds (bool, optional): forwarded to the warping

    Raises:
        ValueError: if ``kernel`` is not one of the supported names.
    """

    def __init__(
        self,
        num_outputs,
        t_span,
        n_tau,
        learn_inducing_locations: bool = False,
        whitened_param: bool = True,
        kernel: str = "matern52",
        apply_input_warping: bool = True,
        len_init: float = 1.0,
        dynamic_update_bounds: bool = True,
    ) -> None:
        super().__init__()
        # padding added on both sides of the time span
        nu = 1.0
        tau = torch.linspace(t_span[0] - nu, t_span[1] + nu, n_tau).unsqueeze(-1)
        if apply_input_warping:
            input_warping = KumaraswamyWarping(
                (t_span[0] - nu, t_span[1] + nu),
                dynamic_update_bounds=dynamic_update_bounds,
            )
        else:
            input_warping = IdentityWarp()
        self.learn_inducing_locations = learn_inducing_locations
        self.whitened_param = whitened_param
        if learn_inducing_locations:
            self.tau = Parameter(tau)
        else:
            self.register_buffer("tau", tau)
        # if init_with_grad = True, K(tau, tau) will be size (n_tau, 2*n_tau)
        init_with_grad = False
        if kernel == "matern52":
            self.K = Matern52(input_warping=input_warping, len_init=len_init)
        elif kernel == "matern52withgradients":
            init_with_grad = True
            self.K = Matern52withGradients(
                input_warping=input_warping, len_init=len_init
            )
        elif kernel == "matern12":
            self.K = Matern12(input_warping=input_warping, len_init=len_init)
        else:
            raise ValueError(f"Unknown kernel {kernel}")
        self.register_buffer("eps", torch.tensor(1e-6))
        if not init_with_grad:
            K_tmp = self.K(self.tau, self.tau)
        else:
            warnings.warn(
                "Learning inducing locations is not supported with gradient glms."
            )
            K_tuple = self.K(self.tau, self.tau, return_grad=True)
            K_tmp = torch.cat(K_tuple, dim=0)
        n_weights = K_tmp.shape[-1]
        # jitter is built on the same device / dtype as the kernel matrix
        jitter = self.eps * torch.eye(
            n_weights, device=K_tmp.device, dtype=K_tmp.dtype
        )
        C = torch.linalg.cholesky(K_tmp + jitter)
        self.register_buffer("C_buffer", C.detach())
        self.raw_w = Parameter(torch.empty(n_weights, num_outputs))
        self.b = Parameter(torch.empty(num_outputs))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        Initialize the weights and zero the bias.

        ``raw_w`` is always drawn from a standard normal first. In the
        non-whitened parametrization that draw is then mapped through
        ``C^{-T}`` so that ``raw_w`` holds ``w`` directly. (Previously the
        non-whitened branch read the parameter's uninitialized memory.)
        """
        nn.init.normal_(self.raw_w, mean=0.0, std=1.0)
        if not self.whitened_param:
            with torch.no_grad():
                w_init = torch.triangular_solve(self.raw_w, self.C.t()).solution
            self.raw_w.data.copy_(w_init)
        nn.init.zeros_(self.b)

    @property
    def n_tau(self) -> int:
        """Number of inducing locations."""
        return len(self.tau)

    @property
    def C(self):
        """
        Cholesky factor of ``K(tau, tau)`` plus jitter.

        Recomputed on every access when the inducing locations are learned,
        otherwise the factor cached at construction is returned.
        """
        if not self.learn_inducing_locations:
            return self.C_buffer
        K_tau = self.K(self.tau, self.tau)
        # create the identity where the kernel matrix lives so this also
        # works after the module has been moved to another device
        jitter = self.eps * torch.eye(
            self.n_tau, device=K_tau.device, dtype=K_tau.dtype
        )
        return torch.linalg.cholesky(K_tau + jitter)

    @property
    def w(self):
        """Weights applied to the kernel features (``C^{-T} raw_w`` if whitened)."""
        if self.whitened_param:
            return torch.triangular_solve(self.raw_w, self.C.t()).solution
        else:
            return self.raw_w

    @w.setter
    def w(self, value: Tensor):
        # store C^T w when whitened so that the getter recovers ``value``
        if self.whitened_param:
            self.raw_w.data.copy_(self.C.t() @ value)
        else:
            self.raw_w.data.copy_(value)

    def forward(
        self, t: Tensor, return_grad=False
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """
        Evaluate ``K(t, tau) @ w + b``.

        Args:
            t (Tensor): time stamps; a 0-dim tensor is promoted to 1-dim
            return_grad (bool, optional): also return the time derivative

        Returns:
            The output, or ``(output, d(output)/dt)``.
        """
        if t.dim() == 0:
            t = t.unsqueeze(-1)
        if not return_grad:
            return self.K(t, self.tau) @ self.w + self.b
        else:
            m, dmdt = self.K(t, self.tau, return_grad=True)
            return m @ self.w + self.b, dmdt @ self.w


def train_glm(
    train_x,
    train_y,
    train_dydx,
    num_outputs,
    t_span,
    n_tau,
    whitened_param,
    learn_inducing_locations,
    kernel,
):
    """
    Choose an initial kernel length scale by 5-fold cross validation.

    For every candidate length scale a :class:`MarginalGLM` is built, its
    weights are fitted on each training fold with ``solve_least_squares``
    and the mean squared error on the held-out fold is recorded. When
    ``train_dydx`` is given, derivative observations and derivative kernel
    features are stacked below the function observations / features.

    Args:
        train_x: inputs (time stamps)
        train_y: targets at ``train_x``
        train_dydx: optional derivative targets at ``train_x`` (or None)
        num_outputs, t_span, n_tau, whitened_param,
        learn_inducing_locations, kernel: forwarded to :class:`MarginalGLM`

    Returns:
        The candidate length scale with the smallest mean held-out error
        (the first candidate wins ties).
    """
    len_init_list = [1e-1, 0.5, 1.0, 10.0]
    n_splits = 5

    # shuffled indices cut into equally sized folds; any remainder is
    # never held out but still appears in every training set
    idx = np.arange(len(train_x))
    np.random.shuffle(idx)
    fold_size = len(train_x) // n_splits
    folds = [idx[fold_size * k : fold_size * (k + 1)] for k in range(n_splits)]
    with_grads = train_dydx is not None

    eval_error = {}
    for len_init in len_init_list:
        glm = MarginalGLM(
            num_outputs,
            t_span,
            n_tau,
            whitened_param=whitened_param,
            learn_inducing_locations=learn_inducing_locations,
            kernel=kernel,
            len_init=len_init,
        )
        fold_losses = []
        for held_out in folds:
            train_idx = np.array(list(set(idx) - set(held_out)))
            t, y = train_x[train_idx], train_y[train_idx]
            t_eval, y_eval = train_x[held_out], train_y[held_out]
            if with_grads:
                y = torch.cat([y, train_dydx[train_idx]], dim=0)
                y_eval = torch.cat([y_eval, train_dydx[held_out]], dim=0)
                K_fit, dK_fit = glm.K(t, glm.tau, return_grad=True)
                train_features = torch.cat([K_fit, dK_fit], dim=0)
                K_val, dK_val = glm.K(t_eval, glm.tau, return_grad=True)
                val_features = torch.cat([K_val, dK_val], dim=0)
            else:
                train_features = glm.K(t, glm.tau)
                val_features = glm.K(t_eval, glm.tau)
            w = solve_least_squares(train_features, y, gamma=1e-1)
            residual = val_features @ w - y_eval
            fold_losses.append(residual.pow(2).mean().detach().numpy())
        eval_error[len_init] = np.array(fold_losses).mean()

    # ``min`` keeps the first candidate on ties, like a strict ``<`` scan
    return min(len_init_list, key=eval_error.__getitem__)


class MeanGLM(MarginalGLM):
    """
    GLM for the marginal mean, optionally initialized from data.

    When training data are supplied, the kernel length scale is first
    selected with ``train_glm`` and the weights are then set to the
    regularized least-squares fit of the data (bias set to zero).
    ``forward`` is inherited unchanged from :class:`MarginalGLM`.

    Args:
        num_outputs: number of outputs
        t_span: (t0, tf) time span of the data
        n_tau: number of inducing locations
        learn_inducing_locations (bool, optional): make ``tau`` learnable
        whitened_param (bool, optional): use the whitened parametrization
        kernel (str, optional): kernel name understood by the parent class
        train_x (Tensor, optional): training inputs (time stamps)
        train_y (Tensor, optional): training targets
        train_dydx (Tensor, optional): derivative targets at ``train_x``;
            only used together with ``train_x`` / ``train_y``

    Raises:
        ValueError: if only one of ``train_x`` / ``train_y`` is provided.
    """

    def __init__(
        self,
        num_outputs,
        t_span,
        n_tau,
        learn_inducing_locations: bool = False,
        whitened_param: bool = True,
        kernel: str = "matern52",
        train_x: Optional[Tensor] = None,
        train_y: Optional[Tensor] = None,
        train_dydx: Optional[Tensor] = None,
    ) -> None:
        # Checked once, up front, with an exception that survives
        # ``python -O`` (asserts do not).
        if (train_x is None) != (train_y is None):
            raise ValueError("train_x and train_y must both be provided.")
        has_data = train_x is not None

        if has_data:
            # determine a good initialization for the length scale
            len_init = train_glm(
                train_x,
                train_y,
                train_dydx,
                num_outputs,
                t_span,
                n_tau,
                whitened_param,
                learn_inducing_locations,
                kernel,
            )
        else:
            len_init = 1.0
        super().__init__(
            num_outputs,
            t_span,
            n_tau,
            whitened_param=whitened_param,
            learn_inducing_locations=learn_inducing_locations,
            kernel=kernel,
            len_init=len_init,
        )
        if has_data:
            y = train_y
            if train_dydx is not None:
                # stack derivative observations below function observations
                y = torch.cat([y, train_dydx], dim=0)
                K_x, dK_x = self.K(train_x, self.tau, return_grad=True)
                features = torch.cat([K_x, dK_x], dim=0)
            else:
                features = self.K(train_x, self.tau)
            self.w = solve_least_squares(features, y, gamma=1e-1)
            self.b.data.zero_()


class StrictlyPositiveGLM(MarginalGLM):
    """
    GLM whose linear output is passed through ``Positive()``.

    The weights start at zero and the bias at ``-2.2522`` so the model is
    initially constant in time.

    Args:
        num_outputs: number of outputs
        t_span: (t0, tf) time span of the data
        n_tau: number of inducing locations
        learn_inducing_locations (bool, optional): make ``tau`` learnable
        kernel (str, optional): kernel name understood by the parent class
    """

    def __init__(
        self,
        num_outputs,
        t_span,
        n_tau,
        learn_inducing_locations: bool = False,
        kernel: str = "matern52",
    ) -> None:
        super().__init__(
            num_outputs,
            t_span,
            n_tau,
            learn_inducing_locations=learn_inducing_locations,
            kernel=kernel,
        )
        self.output_transform = Positive()
        # NOTE(review): presumably chosen so that Positive(b) is about 0.1
        # -- confirm against the Positive implementation
        nn.init.constant_(self.b, -2.2522)
        # Size the zero weights from the parameter itself: with the
        # gradient kernel there are more weights than inducing locations.
        self.w = torch.zeros_like(self.raw_w.data)

    def forward(
        self, t: Tensor, return_grad=False
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """
        Evaluate the transformed GLM at ``t``.

        Args:
            t (Tensor): time stamps; a 0-dim tensor is promoted to 1-dim
            return_grad (bool, optional): also return the time derivative

        Returns:
            The transformed output, or ``(output, d(output)/dt)``.
        """
        if t.dim() == 0:
            t = t.unsqueeze(-1)
        if not return_grad:
            return self.output_transform(self.K(t, self.tau) @ self.w + self.b)

        m, dmdt = self.K(t, self.tau, return_grad=True)
        res, grad_res = self.output_transform(m @ self.w + self.b, return_grad=True)
        # chain rule: d(transform)/d(linear) * d(linear)/dt
        return res, grad_res * (dmdt @ self.w)


class OrthogonalGLM(MarginalGLM):
    """
    GLM whose ``num_outputs * (num_outputs - 1) / 2`` linear outputs are
    mapped to a ``num_outputs x num_outputs`` matrix by ``SkewSymMatrixExp``.

    Args:
        num_outputs: size of the square matrix that is returned
        t_span: (t0, tf) time span of the data
        n_tau: number of inducing locations
        learn_inducing_locations (bool, optional): make ``tau`` learnable
        kernel (str, optional): kernel name understood by the parent class
    """

    def __init__(
        self,
        num_outputs,
        t_span,
        n_tau,
        learn_inducing_locations: bool = False,
        kernel: str = "matern52",
    ) -> None:
        super().__init__(
            num_outputs * (num_outputs - 1) // 2,
            t_span,
            n_tau,
            learn_inducing_locations=learn_inducing_locations,
            kernel=kernel,
        )
        self.orthogonal_mexp = SkewSymMatrixExp(num_outputs)
        self.num_outputs = num_outputs
        # Go through the ``w`` setter so the stored parameter is actually
        # updated. Filling ``self.w`` in place only touched the temporary
        # tensor returned by the property getter and left the weights at
        # their random initialization.
        self.w = torch.full_like(self.raw_w.data, 1e-6)

    def forward(
        self, t: Tensor, return_grad=False
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """
        Evaluate the matrix-valued model at ``t``.

        Args:
            t (Tensor): time stamps; a 0-dim tensor is promoted to 1-dim
            return_grad (bool, optional): also return the time derivative

        Returns:
            The matrix exponential output, or the tuple of that output and
            its time derivative reshaped to ``(len(t), n, n)``.
        """
        if t.dim() == 0:
            t = t.unsqueeze(-1)
        if not return_grad:
            return self.orthogonal_mexp(self.K(t, self.tau) @ self.w + self.b)

        m, dmdt = self.K(t, self.tau, return_grad=True)
        expS, grad_expS = self.orthogonal_mexp(m @ self.w + self.b, return_grad=True)
        # todo: grad_expS.mul(dmdt @ self.w)?
        dexpSdt = (grad_expS @ (dmdt @ self.w).unsqueeze(-1)).squeeze(-1)
        n = self.num_outputs
        return (expS, dexpSdt.view(len(t), n, n).transpose(-2, -1))


class MarginalSDE(ABC, nn.Module):
    """
    Abstract base class for a model of the marginal statistics of a Markov
    Gaussian process.
    """

    def __init__(self) -> None:
        super().__init__()
        # empty parameter whose only purpose is to report the module's device
        self.device_var = Parameter(torch.empty(0))
        self.low_rank_cov = False  # bool indicating if covariance is low rank

    @property
    def device(self):
        """Device on which ``device_var`` currently lives."""
        return self.device_var.device

    @abstractmethod
    def mean_parameters(self) -> Iterator[Parameter]:
        """
        Returns an iterator over the mean parameters of the marginal SDE.

        Returns:
            Iterator[Parameter]: An iterator over the mean parameters of the
            marginal SDE.
        """
        pass

    @abstractmethod
    def generate_samples(self, t: Tensor, num_samples: int, *args, **kwargs) -> Tensor:
        """
        Generates samples from the approximating SDE marginal distribution at
        time t and optionally returns some intermediate quantities.

        Args:
            t (Tensor): (bs, ) time stamps at which to generate samples
            num_samples (int): number of independent samples at each time stamp

        Returns:
            Tensor: samples of latent states
        """
        pass

    @abstractmethod
    def forward(
        self, t: Tensor, f: Callable[[Tensor, Tensor], Tensor], num_samples: int
    ) -> Tensor:
        """
        Computes the (unweighted) residual loss between the approximating and
        prior SDEs.

        Args:
            t (Tensor): (bs, ) time stamps at which to generate samples
            f (Callable[[Tensor,Tensor],Tensor]): (t, z) -> (num_samples, bs, d)
                or (bs, d) prior SDE evaluations
            num_samples (int): number of reparameterized samples at each time
                stamp

        Returns:
            Tensor: (bs,) approximate residual loss
        """
        pass

    @abstractmethod
    def K(self, t: Tensor) -> Tensor:
        """
        Returns the covariance matrix K(t) evaluated at a batch of times.

        Args:
            t (Tensor): Time stamps at which to compute the covariance matrix

        Returns:
            Tensor: (len(t), d, d) batch of covariance matrices
        """
        pass

    @abstractmethod
    def mean(
        self, t: Tensor, return_grad=False
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """
        Returns the marginal mean at time t.

        Args:
            t (Tensor): (bs,) time stamps at which to evaluate mean
            return_grad (bool, optional): also return the time derivative of
                the mean. Defaults to False.

        Returns:
            Tensor: (bs, d) mean evaluated at times, or the tuple of the mean
            and its time derivative when ``return_grad`` is True
        """
        pass
class DiagonalMarginalSDE(MarginalSDE):
    """
    A model for the marginal statistics of a Markov GP whose covariance is
    diagonal.

    Args:
        d (int): dimension of the state
        t_span (Union[Tuple[float, float], Tuple[Tensor, Tensor]]): time span
            of data
        diffusion_prior (DiagonalDiffusionPrior): Some diagonal diffusion prior
        model_form (str, optional): GLM or FCNN (FCNN not fully tested)
        vmap (bool, optional): DEPRECATED
        **kwargs: arguments passed onto GLM / FCNN. The keys read here are
            ``train_x``, ``train_y``, ``train_dydx``, ``kernel``, ``n_tau``,
            ``learn_inducing_locations`` (GLM) and ``hidden_size``,
            ``num_layers`` (FCNN).

    Raises:
        ValueError: if ``model_form`` is neither "GLM" nor "FCNN".
    """

    def __init__(
        self,
        d: int,
        t_span: Union[Tuple[float, float], Tuple[Tensor, Tensor]],
        diffusion_prior: DiagonalDiffusionPrior,
        model_form: str = "GLM",
        vmap: bool = False,
        **kwargs,
    ) -> None:
        super().__init__()
        self.vmap = vmap
        t0 = t_span[0]
        tf = t_span[1]
        self.d = d
        self.diffusion_prior = diffusion_prior
        train_x = kwargs.get("train_x", None)
        train_y = kwargs.get("train_y", None)
        train_dydx = kwargs.get("train_dydx", None)
        if model_form == "GLM":
            kernel = kwargs.get("kernel", "matern52")
            n_tau = kwargs.get("n_tau", 100)
            learn_inducing_locations = kwargs.get("learn_inducing_locations", False)
            self.m = MeanGLM(
                d,
                (t0, tf),
                n_tau,
                learn_inducing_locations=learn_inducing_locations,
                kernel=kernel,
                train_x=train_x,
                train_y=train_y,
                train_dydx=train_dydx,
            )
            # the variance model always uses the plain Matern 5/2 kernel,
            # whatever kernel was requested for the mean
            kernel = "matern52"
            self.K_diag = StrictlyPositiveGLM(
                d,
                (t0, tf),
                n_tau,
                learn_inducing_locations=learn_inducing_locations,
                kernel=kernel,
            )
        elif model_form == "FCNN":
            hidden_size = kwargs.get("hidden_size", 50)
            num_layers = kwargs.get("num_layers", 2)
            self.m = MeanFCNN(
                d, (t0, tf), hidden_size, num_layers, train_t=train_x, train_x=train_y
            )
            self.K_diag = StrictlyPositiveFCNN(d, (t0, tf), hidden_size, num_layers)
        else:
            raise ValueError(f"Invalid model_form: {model_form}")

    def mean(
        self, t: Tensor, return_grad=False
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """
        Returns the marginal mean at time t.

        Args:
            t (Tensor): (bs,) time stamps at which to evaluate mean
            return_grad (bool, optional): also return the time derivative

        Returns:
            Tensor: (bs, d) mean evaluated at times, or the tuple of the mean
            and its time derivative when ``return_grad`` is True
        """
        if return_grad:
            return self.m(t, return_grad=True)
        return self.m(t)

    def mean_parameters(self) -> Iterator[Parameter]:
        """
        Returns an iterator over the mean parameters of the marginal SDE.

        Returns:
            Iterator[Parameter]: An iterator over the mean parameters of the
            marginal SDE.
        """
        return self.m.parameters()

    def K(self, t: Tensor) -> Tensor:
        """
        Compute the marginal covariance matrix at time t.

        Args:
            t (Tensor): Time stamps at which to compute the covariance matrix

        Returns:
            Tensor: (len(t), d, d) batch of (diagonal) covariance matrices
        """
        return self.K_diag(t).diag_embed()

    def drift(self, t: Tensor, z: Tensor) -> Tensor:
        """
        Compute the drift function of the equivalent SDE at times t.

        Args:
            t (Tensor): (bs,) time stamps at which to compute the drift
            z (Tensor): (..., bs, d) batch of latent states

        Returns:
            Tensor: (..., bs, d) batch of drift function evaluations
        """
        # computing parameterization and time derivative
        m, dmdt = self.m(t, return_grad=True)
        K_diag, dKdt_diag = self.K_diag(t, return_grad=True)
        p_noise = self.diffusion_prior.process_noise_diag
        # todo: check if this is correct for mle case
        if dKdt_diag.shape[0] != p_noise.shape[0] and p_noise.dim() == 2:
            # insert a batch axis so the noise broadcasts against (bs, d)
            p_noise = p_noise.unsqueeze(1)
        A = solve_lyapunov_diag(K_diag, -dKdt_diag + p_noise)
        return dmdt - (A * (z - m))

    def unweighted_residual_loss(self, drift: Tensor, f: Tensor) -> Tensor:
        """
        Computes the (unweighted) residual loss between the approximating and
        prior SDEs.

        Args:
            drift (Tensor): (n_samples, bs, d) drift function evaluations
            f (Tensor): (n_samples, bs, d) / (bs, d) prior SDE evaluations

        Returns:
            Tensor: residual loss
        """
        r = drift - f
        return self.diffusion_prior.solve_linear(r.pow(2)).sum(-1)

    def generate_samples(self, t: Tensor, num_samples: int) -> Tensor:
        """
        Generates samples from the approximating SDE marginal distribution at
        time t.

        Args:
            t (Tensor): (bs, ) time stamps at which to generate samples; a
                0-dim tensor is treated as a batch of one
            num_samples (int): number of independent samples at each time stamp

        Returns:
            Tensor: (num_samples, bs, d) samples of latent states
        """
        # ``len`` is undefined for 0-dim tensors, which the mean / variance
        # models promote to a batch of one
        batch_size = 1 if t.dim() == 0 else len(t)
        v = torch.randn(num_samples, batch_size, self.d, device=self.device)
        # mean plus standard deviation times white noise
        return self.mean(t) + (self.K_diag(t).sqrt().mul(v))

    def forward(
        self, t: Tensor, f: Callable[[Tensor, Tensor], Tensor], num_samples: int
    ) -> Tensor:
        """
        Computes the (unweighted) residual loss between the approximating and
        prior SDEs.

        Args:
            t (Tensor): (bs, ) time stamps at which to generate samples
            f (Callable[[Tensor,Tensor],Tensor]): (t, z) -> (num_samples, bs, d)
                or (bs, d) prior SDE evaluations
            num_samples (int): number of reparameterized samples at each time
                stamp

        Returns:
            Tensor: (bs,) approximate residual loss
        """
        zs = self.generate_samples(t, num_samples)
        drift = self.drift(t, zs)
        f_samples = f(t, zs)
        # Monte Carlo average over the sample axis
        return self.unweighted_residual_loss(drift, f_samples).mean(0)
class SpectralMarginalSDE(MarginalSDE):
    """
    A model for the marginal statistics of a Markov GP using the spectral
    parametrization described in the main text: the covariance is built as
    ``U diag(D) U^T`` from an orthogonal-matrix model ``U`` and a strictly
    positive eigenvalue model ``D``.

    Args:
        d (int): dimension of the state
        t_span (Union[Tuple[float, float], Tuple[Tensor, Tensor]]): time span
            of data
        diffusion_prior (DiffusionPrior): Some diffusion prior
        model_form (str, optional): GLM or FCNN (FCNN not fully tested)
        vmap (bool, optional): DEPRECATED
        **kwargs: arguments passed onto GLM / FCNN. The keys read here are
            ``train_x``, ``train_y``, ``kernel``, ``n_tau``,
            ``learn_inducing_locations`` (GLM) and ``hidden_size``,
            ``num_layers`` (FCNN).

    Raises:
        ValueError: if ``model_form`` is neither "GLM" nor "FCNN".
    """

    def __init__(
        self,
        d: int,
        t_span: Union[Tuple[float, float], Tuple[Tensor, Tensor]],
        diffusion_prior: DiffusionPrior,
        model_form: str = "GLM",
        vmap: bool = False,
        **kwargs,
    ) -> None:
        super().__init__()
        t0 = t_span[0]
        tf = t_span[1]
        self.vmap = vmap
        self.d = d
        self.register_buffer("tril_ind", torch.tril_indices(d, d, offset=-1))
        self.diffusion_prior = diffusion_prior
        self.register_buffer("ident", torch.eye(d))
        train_x = kwargs.get("train_x", None)
        train_y = kwargs.get("train_y", None)
        if model_form == "GLM":
            kernel = kwargs.get("kernel", "matern52")
            n_tau = kwargs.get("n_tau", 100)
            learn_inducing_locations = kwargs.get("learn_inducing_locations", False)
            self.m = MeanGLM(
                d,
                (t0, tf),
                n_tau,
                kernel=kernel,
                learn_inducing_locations=learn_inducing_locations,
                train_x=train_x,
                train_y=train_y,
            )
            self.orthogonal = OrthogonalGLM(
                d,
                (t0, tf),
                n_tau,
                kernel=kernel,
                learn_inducing_locations=learn_inducing_locations,
            )
            self.eigenvals = StrictlyPositiveGLM(
                d,
                (t0, tf),
                n_tau,
                kernel=kernel,
                learn_inducing_locations=learn_inducing_locations,
            )
        elif model_form == "FCNN":
            hidden_size = kwargs.get("hidden_size", 50)
            num_layers = kwargs.get("num_layers", 2)
            # notation for train_t vs train_x is confusing
            self.m = MeanFCNN(
                d, (t0, tf), hidden_size, num_layers, train_t=train_x, train_x=train_y
            )
            self.orthogonal = OrthogonalFCNN(d, (t0, tf), hidden_size, num_layers)
            self.eigenvals = StrictlyPositiveFCNN(d, (t0, tf), hidden_size, num_layers)
        else:
            raise ValueError(f"Invalid model_form: {model_form}")

    def mean(
        self, t: Tensor, return_grad=False
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """
        Returns the marginal mean at time t.

        Args:
            t (Tensor): (bs,) time stamps at which to evaluate mean
            return_grad (bool, optional): also return the time derivative

        Returns:
            Tensor: (bs, d) mean evaluated at times, or the tuple of the mean
            and its time derivative when ``return_grad`` is True
        """
        if return_grad:
            return self.m(t, return_grad=True)
        return self.m(t)

    def mean_parameters(self) -> Iterator[Parameter]:
        """
        Returns an iterator over the mean parameters of the marginal SDE.

        Returns:
            Iterator[Parameter]: An iterator over the mean parameters of the
            marginal SDE.
        """
        return self.m.parameters()

    def K(self, t: Tensor) -> Tensor:
        """
        Compute the marginal covariance matrix at time t.

        Args:
            t (Tensor): Time stamps at which to compute the covariance matrix

        Returns:
            Tensor: (len(t), d, d) batch of covariance matrices
        """
        U = self.orthogonal(t)
        D = self.eigenvals(t)
        # U diag(D) U^T, with the diagonal applied by column scaling
        return (U.mul(D.unsqueeze(-2))) @ U.transpose(-2, -1)

    def _dKdt_fast(self, U: Tensor, dUdt: Tensor, D: Tensor, dDdt: Tensor) -> Tensor:
        """
        Compute the time derivative of the marginal covariance matrix using
        intermediate quantities.

        Args:
            U (Tensor): (bs, d, d) batch of orthogonal matrices
            dUdt (Tensor): (bs, d, d) batch of time derivatives of orthogonal
                matrices
            D (Tensor): (bs,d) batch of eigenvalues of the covariance matrix
            dDdt (Tensor): (bs,d) batch of time derivatives of eigenvalues of
                the covariance matrix

        Returns:
            Tensor: (bs, d, d) batch of time derivatives of the marginal
            covariance matrix
        """
        # product rule on U diag(D) U^T
        F1 = (U.mul(D.unsqueeze(-2))) @ dUdt.transpose(-2, -1)  # U D dU^T
        F2 = dUdt.mul(D.unsqueeze(-2))  # dU D
        F3 = U.mul(dDdt.unsqueeze(-2))  # U dD
        dKdt = F1 + (F2 + F3) @ U.transpose(-2, -1)
        return dKdt

    def drift(self, t: Tensor, z: Tensor) -> Tensor:
        """
        Compute the drift function of the equivalent SDE at times t.

        Args:
            t (Tensor): (bs,) time stamps at which to compute the drift
            z (Tensor): (..., bs, d) batch of latent states

        Returns:
            Tensor: (..., bs, d) batch of drift function evaluations
        """
        # computing parameterization and time derivative
        U, dUdt = self.orthogonal(t, return_grad=True)
        D, dDdt = self.eigenvals(t, return_grad=True)
        return self._drift_fast(t, z, U, dUdt, D, dDdt)

    def _drift_fast(
        self, t: Tensor, z: Tensor, U: Tensor, dUdt: Tensor, D: Tensor, dDdt: Tensor
    ) -> Tensor:
        """
        Compute the drift function of the equivalent SDE at times t using
        intermediate quantities.

        Args:
            t (Tensor): (bs,) time stamps at which to compute the drift
            z (Tensor): (..., bs, d) batch of latent states
            U (Tensor): (bs, d, d) batch of orthogonal matrices
            dUdt (Tensor): (bs, d, d) batch of time derivatives of orthogonal
                matrices
            D (Tensor): (bs,d) batch of eigenvalues of the covariance matrix
            dDdt (Tensor): (bs,d) batch of time derivatives of eigenvalues of
                the covariance matrix

        Returns:
            Tensor: (..., bs, d) batch of drift function evaluations
        """
        # computing time derivative of K
        dKdt = self._dKdt_fast(U, dUdt, D, dDdt)
        p_noise = self.diffusion_prior.process_noise
        if dKdt.shape[0] != p_noise.shape[0] and p_noise.dim() == 3:
            # insert a batch axis so the noise broadcasts against (bs, d, d)
            p_noise = p_noise.unsqueeze(1)
        A = solve_lyapunov_spectral(D, U, -dKdt + p_noise)
        # computing mean and time derivative of mean
        m, dmdt = self.m(t, return_grad=True)
        # Only drop the trailing column axis added for the matrix product.
        # A bare ``squeeze()`` would also remove sample / batch / state axes
        # of size one and make the subtraction broadcast incorrectly.
        return dmdt - (A @ (z - m).unsqueeze(-1)).squeeze(-1)

    def unweighted_residual_loss(self, drift: Tensor, f: Tensor) -> Tensor:
        """
        Computes the (unweighted) residual loss between the approximating and
        prior SDEs.

        Args:
            drift (Tensor): (n_samples, bs, d) drift function evaluations
            f (Tensor): (n_samples, bs, d) / (bs, d) prior SDE evaluations

        Returns:
            Tensor: residual loss
        """
        r = drift - f
        return (r.mul(self.diffusion_prior.solve_linear(r))).sum(-1)

    def generate_samples(
        self, t: Tensor, num_samples: int, return_intermediates: bool = False
    ) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]]:
        """
        Generates samples from the approximating SDE marginal distribution at
        time t and optionally returns some intermediate quantities.

        Args:
            t (Tensor): (bs, ) time stamps at which to generate samples; a
                0-dim tensor is treated as a batch of one
            num_samples (int): number of independent samples at each time stamp
            return_intermediates (bool, optional): indicates whether to return
                intermediate quantities. Defaults to False.

        Returns:
            Union[Tensor, Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]]:
            samples of latent states, or the tuple
            ``(samples, U, dUdt, D, dDdt)``
        """
        # get covariance factors at each time stamp
        if return_intermediates:
            U, dUdt = self.orthogonal(t, return_grad=True)
            D, dDdt = self.eigenvals(t, return_grad=True)
        else:
            U = self.orthogonal(t, return_grad=False)
            D = self.eigenvals(t, return_grad=False)
        # (len(t), d, d) square-root factor U diag(sqrt(D))
        L = U.mul(D.sqrt().unsqueeze(-2))
        # independent noise for every sample and time stamp, created on the
        # same device / dtype as the factor it is multiplied with
        batch_size = 1 if t.dim() == 0 else len(t)
        v = torch.randn(
            num_samples, batch_size, self.d, 1, device=L.device, dtype=L.dtype
        )
        # get mean at each time stamp
        mu = self.mean(t)
        zs = mu + (L @ v).squeeze(-1)
        if return_intermediates:
            return zs, U, dUdt, D, dDdt
        return zs

    def forward(
        self, t: Tensor, f: Callable[[Tensor, Tensor], Tensor], num_samples: int
    ) -> Tensor:
        """
        Computes the (unweighted) residual loss between the approximating and
        prior SDEs.

        Args:
            t (Tensor): (bs, ) time stamps at which to generate samples
            f (Callable[[Tensor,Tensor],Tensor]): (t, z) -> (num_samples, bs, d)
                or (bs, d) prior SDE evaluations
            num_samples (int): number of reparameterized samples at each time
                stamp

        Returns:
            Tensor: (bs,) approximate residual loss
        """
        # reuse U, D and their derivatives from sampling for the drift
        zs, U, dUdt, D, dDdt = self.generate_samples(
            t, num_samples, return_intermediates=True
        )
        drift = self._drift_fast(t, zs, U, dUdt, D, dDdt)
        f_samples = f(t, zs)
        # Monte Carlo average over the sample axis
        return self.unweighted_residual_loss(drift, f_samples).mean(0)