import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from torch import Tensor
from typing import List, Union, Tuple, Callable, Optional
from abc import ABC, abstractmethod
from ..utils import *
from ..quadrature._unbiased_quadrature import QuadRule1D
from ..quadrature import UnbiasedGaussLegendreQuad
from ._marginal_sde import *
from ._diffusion_prior import *
from ._likelihood import *
from ._sde_prior import *
from ..variationalsparsebayes.sparse_glm import SparsePolynomialFeatures
# Public names exported by ``from <this module> import *``.
__all__ = [
    "SDELearner",
    "SparsePolynomialSDE",
    "SparsePolynomialIntegratorSDE",
    "StateEstimator",
    "NeuralSDE",
]
class SDELearner(nn.Module):
    """Base class for combining different priors and likelihoods for performing
    simultaneous state and governing equation discovery.

    Args:
        marginal_sde (MarginalSDE): Parameterization for the approximate Markov GP
        likelihood (Likelihood): Likelihood for the observations
        quad_scheme (QuadRule1D): 1D quadrature rule used to approximate the residual
        sde_prior (SDEPrior): The prior over the drift function
        diffusion_prior (DiffusionPrior): The prior over the diffusion matrix
        n_reparam_samples (int): The number of samples to use when doing the
            reparametrization trick
    """

    def __init__(
        self,
        marginal_sde: MarginalSDE,
        likelihood: Likelihood,
        quad_scheme: QuadRule1D,
        sde_prior: SDEPrior,
        diffusion_prior: DiffusionPrior,
        n_reparam_samples: int,
    ) -> None:
        super().__init__()
        self.marginal_sde = marginal_sde
        self.likelihood = likelihood
        self.quad_scheme = quad_scheme
        self.sde_prior = sde_prior
        self.diffusion_prior = diffusion_prior
        self.n_reparam_samples = n_reparam_samples
        # Sanity check: state_params() and sde_params() must together account
        # for every registered parameter exactly once.
        assert len(self.state_params()) + len(self.sde_params()) == len(
            list(self.parameters())
        ), "State and SDE parameters not added correctly."

    def state_params(self) -> List[Parameter]:
        """Returns parameters related to only the state estimator.

        These are the parameters of ``marginal_sde`` that are not also
        parameters of ``diffusion_prior``. The list follows the registration
        order of ``marginal_sde.parameters()``, so it is identical on every run.

        Returns:
            List[Parameter]: parameters related to the state estimator
        """
        # Compare by identity: a set difference selects the same members but
        # returns them in an arbitrary, run-dependent order.
        diffusion_ids = {id(p) for p in self.diffusion_prior.parameters()}
        return [
            p for p in self.marginal_sde.parameters() if id(p) not in diffusion_ids
        ]

    def sde_params(self) -> List[Parameter]:
        """Returns all parameters not related to the state estimate.

        The list follows the registration order of ``self.parameters()``.

        Returns:
            List[Parameter]: parameters not related to the state estimate
        """
        state_ids = {id(p) for p in self.state_params()}
        return [p for p in self.parameters() if id(p) not in state_ids]

    def residual_loss(self, t: Tensor) -> Tensor:
        """Returns the residual loss evaluated at a batch of time stamps.

        Args:
            t (Tensor): batch of time stamps

        Returns:
            Tensor: residual evaluated at time stamps t
        """
        return self.marginal_sde(t, self.sde_prior.drift, self.n_reparam_samples)

    def combine_elbo_terms(
        self, beta, N, log_like, residual_loss, kl_divergence, compat_mode=True
    ):
        """Combines the individual ELBO terms into the normalized objective.

        Args:
            beta (float): weight of the kl-divergence
            N (int): total number of data points
            log_like (Tensor): output of ``likelihood.mean_log_likelihood``
            residual_loss (Tensor): quadrature estimate of the residual
            kl_divergence (Tensor): kl-divergence of the sde and diffusion priors
            compat_mode (bool, optional): if True only the kl-divergence is
                scaled by ``beta``; otherwise ``beta`` scales both the residual
                and the kl-divergence terms

        Returns:
            Tensor: normalized evidence lower bound
        """
        if compat_mode:
            return log_like - 0.5 / N * residual_loss - beta / N * kl_divergence
        return log_like - (0.5 * residual_loss + kl_divergence) * beta / N

    def elbo(
        self,
        t: Tensor,
        x: Tensor,
        beta: float,
        N: int,
        print_loss: bool = False,
        compat_mode: bool = True,
    ) -> Tensor:
        """Computes the reweighted evidence lower bound (i.e. ELBO / N).

        Fresh weights are drawn from the sde and diffusion priors on every call.

        Args:
            t (Tensor): (bs,) time stamps of measurements
            x (Tensor): (bs, D) measurements
            beta (float): weight of kl-divergence
            N (int): total number of data points
            print_loss (bool, optional): print the individual loss terms
            compat_mode (bool, optional): how ``beta`` is applied, see
                ``combine_elbo_terms``

        Returns:
            Tensor: normalized evidence lower bound
        """
        self.sde_prior.resample_weights()
        self.diffusion_prior.resample_weights()
        log_like = self.likelihood.mean_log_likelihood(
            t, x, self.marginal_sde, self.n_reparam_samples
        )
        residual_loss = self.quad_scheme(self.residual_loss)
        if beta == 0:
            # Every combination in combine_elbo_terms multiplies the
            # kl-divergence by beta, so skip evaluating it.
            kl_divergence = torch.tensor(0.0)
        else:
            kl_divergence = (
                self.sde_prior.kl_divergence() + self.diffusion_prior.kl_divergence()
            )
        if print_loss:
            print(
                f"log-likelihood: {log_like.item():.2f}, "
                f"log-residual-likelihood: {-residual_loss * 0.5 / N:.2f}, "
                f"kl-divergence: {kl_divergence.item():.2f}"
            )
        return self.combine_elbo_terms(
            beta, N, log_like, residual_loss, kl_divergence, compat_mode=compat_mode
        )

    # methods for making predictions
    def drift(self, t: Tensor, z: Tensor, summary_type: str = "sample") -> Tensor:
        """Returns the drift function for the approximate SDE.

        Args:
            t (Tensor): time stamps to evaluate drift function
            z (Tensor): states at time stamps
            summary_type (str, optional): "sample" evaluates the drift in
                integration mode; "mean" averages the drift over its leading
                dimension (presumably the reparametrization samples)

        Returns:
            Tensor: drift evaluated at time stamps and states

        Raises:
            ValueError: if ``summary_type`` is not "sample" or "mean"
        """
        if summary_type == "mean":
            return self.sde_prior.drift(t, z).mean(0)
        elif summary_type == "sample":
            return self.sde_prior.drift(t, z, integration_mode=True)
        else:
            raise ValueError(f"summary_type {summary_type} not recognized.")

    def diffusion(self, summary_type: str = "sample") -> Tensor:
        """Returns the lower Cholesky factor of the diffusion prior's process noise.

        Args:
            summary_type (str, optional): "sample" factors the current process
                noise; "mean" factors its average over the leading dimension

        Returns:
            Tensor: the Cholesky factor of the diffusion matrix

        Raises:
            ValueError: if ``summary_type`` is not "sample" or "mean"
        """
        if summary_type == "mean":
            return torch.linalg.cholesky(self.diffusion_prior.process_noise.mean(0))
        elif summary_type == "sample":
            return torch.linalg.cholesky(self.diffusion_prior.process_noise)
        else:
            raise ValueError(f"summary_type {summary_type} not recognized.")

    def resample_sde_params(self, n: Optional[int] = None) -> None:
        """Resample the sde_prior and diffusion_prior.

        Args:
            n (int, optional): number of samples to use when resampling. Each
                prior's own sample count is restored afterwards, even if
                resampling raises.
        """
        if n is None:
            self.sde_prior.resample_weights()
            self.diffusion_prior.resample_weights()
            return
        sde_n = self.sde_prior.n_reparam_samples
        # Some diffusion priors may not define the attribute; fall back to the
        # drift prior's count, as the previous implementation did.
        diffusion_n = getattr(self.diffusion_prior, "n_reparam_samples", sde_n)
        self.sde_prior.n_reparam_samples = n
        self.diffusion_prior.n_reparam_samples = n
        try:
            self.sde_prior.resample_weights()
            self.diffusion_prior.resample_weights()
        finally:
            self.sde_prior.n_reparam_samples = sde_n
            self.diffusion_prior.n_reparam_samples = diffusion_n
class SparsePolynomialSDE(SDELearner):
    """Subclass of SDELearner for the case that the prior over the
    drift function is a sparse linear combination of polynomials, the
    observation matrix is linear, observation noise is a diagonal
    Gaussian, and the Markov GP is parameterized by RBF models
    with the Matern 5/2 kernel. Also performs some useful
    initialization to make training more stable.

    Args:
        d (int): dimension of the state
        t_span (Union[Tuple[float, float], Tuple[Tensor, Tensor]]): min and max boundary for RBF centers
        degree (int): degree of polynomial in drift function
        n_reparam_samples (int): number of samples to use when using the reparametrization trick
        G (Tensor): observation matrix (i.e. y = Gx)
        num_meas (int): number of observations (should be G.shape[0])
        measurement_noise (Tensor): variance of observations
        tau (float): global half-cauchy scaling parameter
        train_t (Tensor, optional): observation time stamps (ns,)
        train_x (Tensor, optional): observations (ns, num_meas)
        input_labels (List[str], optional): name of each variable (i.e. ["x", "y", ...])
        n_quad (int, optional): number of quadrature nodes to use
        quad_percent (float, optional): 1 - quad_percent of the quad_nodes will be sampled uniformly in the time window
        n_tau (int, optional): number of centers for the RBF models
    """

    def __init__(
        self,
        d: int,
        t_span: Union[Tuple[float, float], Tuple[Tensor, Tensor]],
        degree: int,
        n_reparam_samples: int,
        G: Tensor,
        num_meas: int,
        measurement_noise: Tensor,
        tau: float = 1e-5,
        train_t: Optional[Tensor] = None,
        train_x: Optional[Tensor] = None,
        input_labels: Optional[List[str]] = None,
        n_quad: int = 128,
        quad_percent: float = 0.8,
        n_tau: int = 200,
    ) -> None:
        assert quad_percent > 0 and quad_percent < 1, "quad_percent must be in (0, 1)"
        # The diffusion prior starts from a unit diagonal.
        diffusion_prior = SparseDiagonalDiffusionPrior(
            d, torch.ones(d), n_reparam_samples, tau
        )
        # Data-driven initialization needs training data and a G of full
        # column rank, so the states can be estimated from the observations by
        # (regularized) least squares. Note: a former G == I check here could
        # fail to broadcast for non-square G and guarded only a no-op.
        fast_initialization = (
            (torch.linalg.matrix_rank(G) == d)
            and (train_t is not None)
            and (train_x is not None)
        )
        if fast_initialization:
            z = solve_least_squares(G, train_x.t(), gamma=1e-2).t()
        else:
            # No usable data: build the marginal SDE without training targets.
            train_t = None
            z = None
        marginal_sde = SpectralMarginalSDE(
            d,
            t_span,
            diffusion_prior=diffusion_prior,
            model_form="GLM",
            n_tau=n_tau,
            learn_inducing_locations=False,
            train_x=train_t,
            train_y=z,
        )
        quad_scheme = UnbiasedGaussLegendreQuad(
            t_span[0], t_span[1], n_quad, quad_percent=quad_percent
        )
        likelihood = IndepGaussLikelihood(G, num_meas, measurement_noise)
        features = SparsePolynomialFeatures(d, degree=degree, input_labels=input_labels)
        if fast_initialization:
            # Initialize the drift regression from the fitted mean and its
            # time derivative at the observation times.
            m, dmdt = marginal_sde.mean(train_t, return_grad=True)
        else:
            m, dmdt = (None, None)
        sde_prior = SparseMultioutputGLM(
            d,
            SparseFeatures=features,
            n_reparam_samples=n_reparam_samples,
            tau=tau,
            train_x=m,
            train_y=dmdt,
        )
        super().__init__(
            marginal_sde,
            likelihood,
            quad_scheme,
            sde_prior,
            diffusion_prior,
            n_reparam_samples,
        )

    def elbo(
        self,
        t: Tensor,
        x: Tensor,
        beta: float,
        N: int,
        print_loss: bool = False,
        compat_mode: bool = True,
    ) -> Tensor:
        """Computes the reweighted evidence lower bound (i.e. ELBO / N).

        Args:
            t (Tensor): (bs,) time stamps of measurements
            x (Tensor): (bs, D) measurements
            beta (float): weight of kl-divergence
            N (int): total number of data points
            print_loss (bool, optional): print the individual loss terms
            compat_mode (bool, optional): forwarded to ``SDELearner.elbo``

        Returns:
            Tensor: normalized evidence lower bound
        """
        return super().elbo(
            t, x, beta, N, print_loss=print_loss, compat_mode=compat_mode
        )
class SparsePolynomialIntegratorSDE(SDELearner):
    """
    Assumes that states are governed by a drift function which is a function
    of both the position and its velocity but we can only observe
    the states.

    Like the SparsePolynomialSDE, this is a
    Subclass of SDELearner for the case that the prior over the
    drift function is a sparse linear combination of polynomials, the
    observation matrix is linear, observation noise is a diagonal
    Gaussian, the Markov GP is parameterized by RBF models
    with the Matern 5/2 kernel. Also performs some useful
    initialization to make training more stable.

    Args:
        d (int): dimension of the state
        t_span (Union[Tuple[float, float], Tuple[Tensor, Tensor]]): min and max boundary for RBF centers
        degree (int): degree of polynomial in drift function
        n_reparam_samples (int): number of samples to use when using the reparametrization trick
        G (Tensor): observation matrix (i.e. y = Gx)
        num_meas (int): number of observations (should be G.shape[0])
        measurement_noise (Tensor): variance of observations
        tau (float): global half-cauchy scaling parameter
        train_t (Tensor, optional): observation time stamps (ns,)
        train_x (Tensor, optional): observations (ns, num_meas)
        input_labels (List[str], optional): name of each variable (i.e. ["x", "y", ...])
        n_tau (int, optional): number of centers for the RBF models
        n_quad (int, optional): number of quadrature nodes to use
        quad_percent (float, optional): 1 - quad_percent of the quad_nodes will be sampled uniformly in the time window
    """

    def __init__(
        self,
        d: int,
        t_span: Union[Tuple[float, float], Tuple[Tensor, Tensor]],
        degree: int,
        n_reparam_samples: int,
        G: Tensor,
        num_meas: int,
        measurement_noise: Tensor,
        tau: float = 1e-5,
        train_t: Optional[Tensor] = None,
        train_x: Optional[Tensor] = None,
        input_labels: Optional[List[str]] = None,
        n_tau: int = 200,
        n_quad: int = 128,
        quad_percent: float = 0.8,
    ) -> None:
        assert quad_percent > 0 and quad_percent < 1, "quad_percent must be in (0, 1)"
        # The diffusion prior starts from a unit diagonal.
        diffusion_prior = SparseDiagonalDiffusionPrior(
            d, torch.ones(d), n_reparam_samples, tau
        )
        # Fast initialization assumes exactly the first d // 2 states are
        # observed, i.e. G == eye(d)[: d // 2]. The shape is checked first so a
        # differently shaped G falls back to the default initialization
        # instead of failing to broadcast.
        fast_initialization = (
            tuple(G.shape) == (d // 2, d)
            and bool((torch.eye(d)[: d // 2] - G).pow(2).sum() == 0)
            and (train_t is not None)
            and (train_x is not None)
        )
        if fast_initialization:
            # First pass: least-squares states from the observed first half.
            z = solve_least_squares(G, train_x.t(), gamma=1e-2).t()
        else:
            train_t = None
            z = None
        marginal_sde = SpectralMarginalSDE(
            d,
            t_span,
            diffusion_prior=diffusion_prior,
            model_form="GLM",
            n_tau=n_tau,
            learn_inducing_locations=False,
            train_x=train_t,
            train_y=z,
        )
        if fast_initialization:
            # Second pass: use the time derivative of the fitted first half of
            # the state as a target for the second half, then refit.
            _, dmdt = marginal_sde.mean(train_t, return_grad=True)
            G_tmp = torch.cat([G, torch.eye(d)[d // 2 :]], dim=0)
            z = solve_least_squares(
                G_tmp,
                torch.cat([train_x, dmdt[:, : d // 2]], dim=-1).t(),
                gamma=1e-2,
            ).t()
            marginal_sde = SpectralMarginalSDE(
                d,
                t_span,
                diffusion_prior=diffusion_prior,
                model_form="GLM",
                n_tau=n_tau,
                learn_inducing_locations=False,
                train_x=train_t,
                train_y=z,
            )
        quad_scheme = UnbiasedGaussLegendreQuad(
            t_span[0], t_span[1], n_quad, quad_percent=quad_percent
        )
        likelihood = IndepGaussLikelihood(G, num_meas, measurement_noise)
        features = SparsePolynomialFeatures(d, degree=degree, input_labels=input_labels)
        if fast_initialization:
            # Initialize the drift regression from the refitted mean and its
            # time derivative at the observation times.
            m, dmdt = marginal_sde.mean(train_t, return_grad=True)
        else:
            m, dmdt = (None, None)
        sde_prior = SparseIntegratorGLM(
            d,
            SparseFeatures=features,
            n_reparam_samples=n_reparam_samples,
            tau=tau,
            train_x=m,
            train_y=dmdt,
        )
        super().__init__(
            marginal_sde,
            likelihood,
            quad_scheme,
            sde_prior,
            diffusion_prior,
            n_reparam_samples,
        )
class StateEstimator(SDELearner):
    """
    Subclass of SDELearner for the case the drift and diagonal diffusion matrix is known,
    the observation matrix is linear, observation noise is a diagonal
    Gaussian, the Markov GP is parametrized by RBF models
    with the Matern 5/2 kernel. Also performs some useful
    initialization to make training more stable.

    Args:
        d (int): dimension of the state
        t_span (Union[Tuple[float, float], Tuple[Tensor, Tensor]]): min and max boundary for RBF centers
        n_reparam_samples (int): number of samples to use when using the reparametrization trick
        G (Tensor): observation matrix (i.e. y = Gx)
        drift (Callable[[Tensor, Tensor], Tensor]): known drift function
        num_meas (int): number of observations (should be G.shape[0])
        measurement_noise (Tensor): variance of observations
        tau (float): global half-cauchy scaling parameter
        train_t (Tensor, optional): observation time stamps (ns,)
        train_x (Tensor, optional): observations (ns, num_meas)
        n_quad (int, optional): number of quadrature nodes to use
        quad_percent (float, optional): 1 - quad_percent of the quad_nodes will be sampled uniformly in the time window
        n_tau (int, optional): number of centers for the RBF models
        Q_diag (Tensor, optional): diagonal of a fixed diffusion matrix; when
            omitted, a SparseDiagonalDiffusionPrior with unit diagonal is used
    """

    def __init__(
        self,
        d: int,
        t_span: Union[Tuple[float, float], Tuple[Tensor, Tensor]],
        n_reparam_samples: int,
        G: Tensor,
        drift: Callable[[Tensor, Tensor], Tensor],
        num_meas: int,
        measurement_noise: Tensor,
        tau: float = 1e-5,
        train_t: Optional[Tensor] = None,
        train_x: Optional[Tensor] = None,
        n_quad: int = 128,
        quad_percent: float = 0.8,
        n_tau: int = 200,
        Q_diag: Optional[Tensor] = None,
    ) -> None:
        assert quad_percent > 0 and quad_percent < 1, "quad_percent must be in (0, 1)"
        if Q_diag is None:
            diffusion_prior = SparseDiagonalDiffusionPrior(
                d, torch.ones(d), n_reparam_samples, tau
            )
        else:
            diffusion_prior = ConstantDiagonalDiffusionPrior(d, Q_diag)
        # Data-driven initialization needs training data and a G of full
        # column rank, so the states can be estimated from the observations by
        # (regularized) least squares. Note: a former G == I check here could
        # fail to broadcast for non-square G and guarded only a no-op.
        fast_initialization = (
            (torch.linalg.matrix_rank(G) == d)
            and (train_t is not None)
            and (train_x is not None)
        )
        if fast_initialization:
            z = solve_least_squares(G, train_x.t(), gamma=1e-2).t()
        else:
            train_t = None
            z = None
        marginal_sde = SpectralMarginalSDE(
            d,
            t_span,
            diffusion_prior=diffusion_prior,
            model_form="GLM",
            n_tau=n_tau,
            learn_inducing_locations=False,
            train_x=train_t,
            train_y=z,
        )
        quad_scheme = UnbiasedGaussLegendreQuad(
            t_span[0], t_span[1], n_quad, quad_percent=quad_percent
        )
        likelihood = IndepGaussLikelihood(G, num_meas, measurement_noise)
        # The drift is known, so it is wrapped directly; no data-driven
        # initialization of the drift is needed.
        sde_prior = ExactMotionModel(drift)
        super().__init__(
            marginal_sde,
            likelihood,
            quad_scheme,
            sde_prior,
            diffusion_prior,
            n_reparam_samples,
        )

    def elbo(
        self,
        t: Tensor,
        x: Tensor,
        beta: float,
        N: int,
        print_loss: bool = False,
        compat_mode: bool = True,
    ) -> Tensor:
        """Computes the reweighted evidence lower bound (i.e. ELBO / N).

        Args:
            t (Tensor): (bs,) time stamps of measurements
            x (Tensor): (bs, D) measurements
            beta (float): weight of kl-divergence
            N (int): total number of data points
            print_loss (bool, optional): print the individual loss terms
            compat_mode (bool, optional): forwarded to ``SDELearner.elbo``

        Returns:
            Tensor: normalized evidence lower bound
        """
        return super().elbo(
            t, x, beta, N, print_loss=print_loss, compat_mode=compat_mode
        )
class NeuralSDE(SDELearner):
    """SDELearner whose drift function is a fully connected neural network.

    The approximate Markov GP is a ``DiagonalMarginalSDE`` built from RBF
    models, and observations enter through an ``IndepGaussLikelihood`` with
    the linear observation matrix ``G``.

    Args:
        d (int): dimension of the state
        t_span (Union[Tuple[float, float], Tuple[Tensor, Tensor]]): min and max boundary for RBF centers
        n_reparam_samples (int): number of samples to use when using the reparametrization trick
        G (Tensor): observation matrix (i.e. y = Gx); its row count sets the number of measurements
        drift_layer_description (List[int]): number of neurons in each layer of the drift function
        nonlinearity (nn.Module): nonlinearity to use in the drift function
        measurement_noise (Tensor): variance of observations
        tau (float): global half-cauchy scaling parameter
        train_t (Tensor, optional): observation time stamps (ns,)
        train_x (Tensor, optional): observations (ns, num_meas)
        n_quad (int, optional): number of quadrature nodes to use
        quad_percent (float, optional): 1 - quad_percent of the quad_nodes will be sampled uniformly in the time window
        n_tau (int, optional): number of centers for the RBF models
        Q_diag (Tensor, optional): diagonal of a fixed diffusion matrix; when
            omitted, a SparseDiagonalDiffusionPrior with unit diagonal is used
    """

    def __init__(
        self,
        d: int,
        t_span: Union[Tuple[float, float], Tuple[Tensor, Tensor]],
        n_reparam_samples: int,
        G: Tensor,
        drift_layer_description: List[int],
        nonlinearity: nn.Module,
        measurement_noise: Tensor,
        tau: float = 1e-5,
        train_t: Optional[Tensor] = None,
        train_x: Optional[Tensor] = None,
        n_quad: int = 128,
        quad_percent: float = 0.8,
        n_tau: int = 200,
        Q_diag: Optional[Tensor] = None,
    ) -> None:
        assert quad_percent > 0 and quad_percent < 1, "quad_percent must be in (0, 1)"

        # Diffusion: fixed when the caller supplies it, otherwise a sparse prior.
        if Q_diag is not None:
            diffusion_prior = ConstantDiagonalDiffusionPrior(d, Q_diag)
        else:
            diffusion_prior = SparseDiagonalDiffusionPrior(
                d, torch.ones(d), n_reparam_samples, tau
            )

        # States can be seeded from data only with training data and a
        # full-column-rank observation matrix.
        rank_ok = torch.linalg.matrix_rank(G) == d
        seed_from_data = rank_ok and (train_t is not None) and (train_x is not None)
        if not seed_from_data:
            train_t, z_init = None, None
        else:
            assert isinstance(train_x, Tensor)  # narrows Optional for type checkers
            z_init = solve_least_squares(G, train_x.t(), gamma=1e-2).t()

        marginal_sde = DiagonalMarginalSDE(
            d,
            t_span,
            diffusion_prior=diffusion_prior,
            model_form="GLM",
            n_tau=n_tau,
            learn_inducing_locations=False,
            train_x=train_t,
            train_y=z_init,
        )
        lower, upper = t_span[0], t_span[1]
        quad_scheme = UnbiasedGaussLegendreQuad(
            lower, upper, n_quad, quad_percent=quad_percent
        )
        n_meas = G.shape[0]
        likelihood = IndepGaussLikelihood(G, n_meas, measurement_noise)
        drift_net = DriftFCNN(d, drift_layer_description, nonlinearity)

        super().__init__(
            marginal_sde=marginal_sde,
            likelihood=likelihood,
            quad_scheme=quad_scheme,
            sde_prior=drift_net,
            diffusion_prior=diffusion_prior,
            n_reparam_samples=n_reparam_samples,
        )