SDE prior
This section contains utilities for defining priors over the drift function of the SDE that serves as the prior over the latent state. All utilities defined after the SDEPrior base class inherit from that class.
SDEPrior base class
- class svise.sde_learning.SDEPrior
Abstract base class for SDE priors. Any subclass that inherits from this class and implements all of its abstract methods is compatible with the SDELearner class.
- abstract drift(t: Tensor, z: Tensor, integration_mode=False) → Tensor
Computes the derivatives for a given setting of the variational distribution weights.
- abstract forward(t: Tensor, z: Tensor) → Tensor
Computes derivatives under the variational distribution q(theta) for a batch of latent states. Often this function will call self.resample_weights() followed by self.drift(t, z).
- Parameters:
t (Tensor) – (bs, ) time stamps at which to compute the drift
z (Tensor) – (n_reparam_samples, bs, d) batch of latent states
- Returns:
(n_reparam_samples, bs, d) batch of drift function evaluations
- Return type:
Tensor
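The contract above can be illustrated with a minimal sketch. It uses a hypothetical stand-in base class, SDEPriorSketch, with NumPy arrays in place of torch Tensors so it runs without svise or PyTorch. Two assumptions are worth flagging: resample_weights() is mentioned in the forward docstring but its signature is not documented here, so it is assumed to take no arguments; and unlike the real SDEPrior, where forward is abstract, the sketch gives forward a default body implementing the "resample, then call drift" pattern. The toy subclass LinearDecayPrior is purely illustrative.

```python
from abc import ABC, abstractmethod

import numpy as np


class SDEPriorSketch(ABC):
    """Hypothetical stand-in mirroring the documented SDEPrior interface."""

    @abstractmethod
    def resample_weights(self) -> None:
        """Draw fresh weight samples from q(theta) (assumed signature)."""

    @abstractmethod
    def drift(self, t: np.ndarray, z: np.ndarray, integration_mode: bool = False) -> np.ndarray:
        """Evaluate the drift for the current weight samples."""

    def forward(self, t: np.ndarray, z: np.ndarray) -> np.ndarray:
        # Common pattern: draw new weights, then evaluate the drift with them.
        self.resample_weights()
        return self.drift(t, z)


class LinearDecayPrior(SDEPriorSketch):
    """Toy prior: dz/dt = -a * z with a ~ q(a) = N(1, 0.1^2) (illustrative only)."""

    def __init__(self, n_reparam_samples: int, seed: int = 0):
        self.n_reparam_samples = n_reparam_samples
        self.rng = np.random.default_rng(seed)
        self.resample_weights()

    def resample_weights(self) -> None:
        # One decay rate per reparameterized sample, shaped to broadcast over (bs, d).
        self.a = self.rng.normal(1.0, 0.1, size=(self.n_reparam_samples, 1, 1))

    def drift(self, t, z, integration_mode=False):
        # z: (n_reparam_samples, bs, d); the output keeps the same shape.
        return -self.a * z


prior = LinearDecayPrior(n_reparam_samples=4)
t = np.linspace(0.0, 1.0, 8)  # (bs,)
z = np.ones((4, 8, 3))        # (n_reparam_samples, bs, d)
out = prior.forward(t, z)
print(out.shape)  # (4, 8, 3)
```

Each call to forward draws a new decay rate per sample, which is what makes the returned drift a Monte Carlo sample under q(theta) rather than a single deterministic evaluation.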
Exact motion model
- class svise.sde_learning.ExactMotionModel(f: Callable[[Tensor, Tensor], Tensor])
A prior over the drift function for use when the exact form of the drift function is known (i.e., standard state estimation).
- Parameters:
f (Callable[[Tensor, Tensor], Tensor]) – exact drift function
- drift(t: Tensor, z: Tensor, integration_mode=False) → Tensor
Computes the derivatives for a given setting of the variational distribution weights.
- forward(t: Tensor, z: Tensor) → Tensor
Computes derivatives under the variational distribution q(theta) for a batch of latent states.
- Parameters:
t (Tensor) – (bs, ) time stamps at which to compute the drift
z (Tensor) – (n_reparam_samples, bs, d) batch of latent states
- Returns:
(n_reparam_samples, bs, d) batch of drift function evaluations
- Return type:
Tensor
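As a rough illustration of the kind of f this class expects, the sketch below defines a known drift with the (t, z) → dz/dt signature and wraps it in a hypothetical ExactMotionSketch class. It assumes, without confirmation from this page, that with no learnable weights both drift and forward simply evaluate f on the full (n_reparam_samples, bs, d) batch. NumPy stands in for torch, and harmonic_oscillator is an example drift chosen for illustration.

```python
import numpy as np


def harmonic_oscillator(t: np.ndarray, z: np.ndarray, omega: float = 2.0) -> np.ndarray:
    """Known drift for z = (position, velocity): dx/dt = v, dv/dt = -omega^2 x."""
    x, v = z[..., 0], z[..., 1]
    return np.stack([v, -omega**2 * x], axis=-1)


class ExactMotionSketch:
    """Hypothetical stand-in: a drift 'prior' with no variational weights."""

    def __init__(self, f):
        self.f = f

    def drift(self, t, z, integration_mode=False):
        return self.f(t, z)

    def forward(self, t, z):
        # Nothing to resample, so every reparameterized sample sees the same f.
        return self.drift(t, z)


model = ExactMotionSketch(harmonic_oscillator)
t = np.zeros(5)                     # (bs,)
z = np.tile([1.0, 0.0], (3, 5, 1))  # (n_reparam_samples=3, bs=5, d=2)
dz = model.forward(t, z)
print(dz.shape)  # (3, 5, 2)
```

Writing f with ellipsis indexing (z[..., 0]) lets the same function handle any leading batch dimensions, which matches the batched shapes documented above.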
Sparse linear model
- class svise.sde_learning.SparseMultioutputGLM(d: int, SparseFeatures: SparseFeaturesLibrary, n_reparam_samples: int, tau: float = 1e-05, train_x: Tensor | None = None, train_y: Tensor | None = None, resample_on_init: bool = True, transform: ScaleTransform | None = None)
Sparse multioutput generalized linear model for the drift function.
- Parameters:
d (int) – Dimension of output.
SparseFeatures (SparseFeaturesLibrary) – Sparse features library.
n_reparam_samples (int) – Number of reparameterized samples.
tau (float, optional) – Prior on the global scaling parameter. Defaults to 1e-5.
train_x (Tensor, optional) – Training states (for initialization).
train_y (Tensor, optional) – Training state derivatives (for initialization).
resample_on_init (bool, optional) – Whether to resample the weights after initialization. Defaults to True.
transform (ScaleTransform, optional) – NOT TESTED, DO NOT USE
- drift(t: Tensor, z: Tensor, integration_mode=False) → Tensor
Computes the derivatives for a given setting of the variational distribution weights.
- forward(t: Tensor, z: Tensor) → Tensor
Computes derivatives under the variational distribution q(theta) for a batch of latent states. Often this function will call self.resample_weights() followed by self.drift(t, z).
- Parameters:
t (Tensor) – (bs, ) time stamps at which to compute the drift
z (Tensor) – (n_reparam_samples, bs, d) batch of latent states
- Returns:
(n_reparam_samples, bs, d) batch of drift function evaluations
- Return type:
Tensor
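A generalized linear drift of this kind can be sketched as features(z) @ W, where W is drawn from a variational distribution by reparameterization. The sketch below is an assumption-laden illustration, not svise's implementation: quadratic_features is a hypothetical stand-in for SparseFeaturesLibrary, q(W) is assumed to be a mean-field Gaussian, and the sparsity-inducing prior controlled by tau is not reproduced. NumPy replaces torch.

```python
import numpy as np


def quadratic_features(z: np.ndarray) -> np.ndarray:
    """Hypothetical feature library: [1, z_i, z_i * z_j (i <= j)] on the last axis."""
    d = z.shape[-1]
    ones = np.ones(z.shape[:-1] + (1,))
    quad = [z[..., i:i + 1] * z[..., j:j + 1] for i in range(d) for j in range(i, d)]
    return np.concatenate([ones, z] + quad, axis=-1)


class SparseGLMSketch:
    """Drift = features(z) @ W with W ~ q(W), assumed mean-field Gaussian."""

    def __init__(self, d: int, n_features: int, n_reparam_samples: int, seed: int = 0):
        self.rng = np.random.default_rng(seed)
        self.n_reparam_samples = n_reparam_samples
        self.mean = np.zeros((n_features, d))
        self.std = np.full((n_features, d), 1e-2)
        self.resample_weights()

    def resample_weights(self) -> None:
        # Reparameterization: W = mean + std * eps with eps ~ N(0, I).
        eps = self.rng.standard_normal((self.n_reparam_samples,) + self.mean.shape)
        self.W = self.mean + self.std * eps  # (n_reparam_samples, n_features, d)

    def drift(self, t, z, integration_mode=False):
        theta = quadratic_features(z)  # (n_reparam_samples, bs, n_features)
        return theta @ self.W          # batched matmul -> (n_reparam_samples, bs, d)

    def forward(self, t, z):
        self.resample_weights()
        return self.drift(t, z)


d, S, bs = 2, 4, 10
model = SparseGLMSketch(d=d, n_features=6, n_reparam_samples=S)  # 1 + 2 + 3 features for d = 2
z = np.random.default_rng(1).standard_normal((S, bs, d))
out = model.forward(np.zeros(bs), z)
print(out.shape)  # (4, 10, 2)
```

Because the drift is linear in W, each reparameterized sample only costs one batched matrix product; sparsity would show up as most entries of the posterior mean being driven toward zero.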
Second order sparse linear model
- class svise.sde_learning.SparseIntegratorGLM(d: int, SparseFeatures: SparseFeaturesLibrary, n_reparam_samples: int, integrator_indices: List[int] | None = None, tau: float = 1e-05, train_x: Tensor | None = None, train_y: Tensor | None = None)
Assumes the governing equations can be written in the form: d^2x/dt^2 = f(x, dx/dt), where f is a sparse linear combination of functions from the features library.
- Parameters:
d (int) – Dimension of output.
SparseFeatures (SparseFeaturesLibrary) – Sparse features library.
n_reparam_samples (int) – Number of reparameterized samples.
integrator_indices (List[int], optional) – Indices of the states that correspond to unknown dynamics.
tau (float, optional) – Prior on the global scaling parameter. Defaults to 1e-5.
train_x (Tensor, optional) – Training states (for initialization).
train_y (Tensor, optional) – Training state derivatives (for initialization).
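A second-order model d^2x/dt^2 = f(x, dx/dt) is usually handled by rewriting it as a first-order system in z = (x, v): dx/dt = v, dv/dt = f(x, v), so only f needs to be learned. The sketch below shows that reduction under an assumed layout (positions in the first half of the state, velocities in the second); how svise maps integrator_indices onto the state is not specified on this page, so the layout and the example f (damped_spring) are hypothetical. NumPy stands in for torch.

```python
import numpy as np


def second_order_drift(z: np.ndarray, f) -> np.ndarray:
    """First-order form of d^2x/dt^2 = f(x, dx/dt) with state z = [x, v].

    Assumed layout: first half of the last axis holds x, second half holds v.
    """
    n = z.shape[-1] // 2
    x, v = z[..., :n], z[..., n:]
    # dx/dt is known exactly (it is v); only dv/dt = f(x, v) is unknown.
    return np.concatenate([v, f(x, v)], axis=-1)


def damped_spring(x, v):
    # Hypothetical "true" acceleration: d^2x/dt^2 = -x - 0.1 * dx/dt
    return -x - 0.1 * v


z = np.array([[[1.0, 2.0, 0.5, -1.0]]])  # (1, 1, 4): x = (1, 2), v = (0.5, -1)
out = second_order_drift(z, damped_spring)
print(out)
```

Splitting the state this way halves the number of outputs the sparse model has to explain, since the position derivatives follow directly from the velocities.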
Sparse linear model of neighbours
- class svise.sde_learning.SparseNeighbourGLM(d: int, SparseFeatures: SparseFeaturesLibrary, n_reparam_samples: int, tau: float = 1e-05, train_x: Tensor | None = None, train_y: Tensor | None = None, resample_on_init: bool = True, transform: ScaleTransform | None = None)
Sparse multioutput generalized linear model in which the drift of each state is assumed to be a function of its neighbouring states.
- Parameters:
d (int) – Dimension of output.
SparseFeatures (SparseFeaturesLibrary) – Sparse features library.
n_reparam_samples (int) – Number of reparameterized samples.
tau (float, optional) – Prior on the global scaling parameter. Defaults to 1e-5.
train_x (Tensor, optional) – Training states (for initialization).
train_y (Tensor, optional) – Training state derivatives (for initialization).
resample_on_init (bool, optional) – Whether to resample the weights after initialization. Defaults to True.
transform (ScaleTransform, optional) – NOT TESTED, DO NOT USE
- drift(t: Tensor, z: Tensor, integration_mode=False) → Tensor
Computes the derivatives for a given setting of the variational distribution weights.
- forward(t: Tensor, z: Tensor) → Tensor
Computes derivatives under the variational distribution q(theta) for a batch of latent states. Often this function will call self.resample_weights() followed by self.drift(t, z).
- Parameters:
t (Tensor) – (bs, ) time stamps at which to compute the drift
z (Tensor) – (n_reparam_samples, bs, d) batch of latent states
- Returns:
(n_reparam_samples, bs, d) batch of drift function evaluations
- Return type:
Tensor
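One way to picture a neighbour-based drift is a periodic 1-D lattice where the derivative of site i depends only on (z_{i-1}, z_i, z_{i+1}) through a small feature set whose coefficients are shared by every site. The sketch below illustrates that idea; the periodic boundary, the radius of 1, the five features, and the shared coefficients are all assumptions made for illustration, not a description of how SparseNeighbourGLM defines neighbours. NumPy replaces torch.

```python
import numpy as np


def neighbour_stack(z: np.ndarray, radius: int = 1) -> np.ndarray:
    """Gather (z_{i-r}, ..., z_{i+r}) for every site i of a periodic 1-D lattice.

    z: (n_reparam_samples, bs, d) -> (n_reparam_samples, bs, d, 2 * radius + 1)
    """
    # np.roll(z, s, axis=-1)[..., i] == z[..., i - s], so s = +r yields z_{i-r}.
    shifts = range(radius, -radius - 1, -1)
    return np.stack([np.roll(z, s, axis=-1) for s in shifts], axis=-1)


def neighbour_drift(z: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Drift of each site from features of its radius-1 neighbourhood.

    w: (n_reparam_samples, 5) coefficients shared by every site (assumption).
    """
    nb = neighbour_stack(z, radius=1)
    left, centre, right = nb[..., 0], nb[..., 1], nb[..., 2]
    feats = np.stack([np.ones_like(centre), left, centre, right, left * right], axis=-1)
    # Contract the feature axis: (S, bs, d, 5) x (S, 5) -> (S, bs, d)
    return np.einsum("sbdk,sk->sbd", feats, w)


z = np.arange(5.0).reshape(1, 1, 5)        # one sample, one time point, d = 5 sites
w = np.array([[0.0, 0.0, 0.0, 0.0, 1.0]])  # keep only the left * right feature
out = neighbour_drift(z, w)
print(out.shape)  # (1, 1, 5)
```

Sharing one coefficient vector across sites keeps the number of unknowns independent of d, which is the main appeal of a neighbour-based model for large spatially structured systems.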