Markov Gaussian process
This section contains utilities for defining the approximate posterior over the state.
MarginalSDE base class
- class svise.sde_learning.MarginalSDE
Abstract base class for a model of the marginal statistics of a Markov Gaussian process.
- abstract K(t: Tensor) → Tensor
Returns the covariance matrix K(t) evaluated at a batch of times.
- Parameters:
t (Tensor) – Time stamps at which to compute the covariance matrix
- Returns:
(len(t), d, d) batch of covariance matrices
- Return type:
Tensor
- abstract forward(t: Tensor, f: Callable[[Tensor, Tensor], Tensor], num_samples: int) → Tensor
Computes the (unweighted) residual loss between the approximating and prior SDEs.
- Parameters:
t (Tensor) – (bs,) time stamps at which to generate samples
f (Callable[[Tensor, Tensor], Tensor]) – prior SDE, called as f(t, z) and returning evaluations of shape (num_samples, bs, d) or (bs, d)
num_samples (int) – number of reparameterized samples at each time stamp
- Returns:
(bs,) approximate residual loss
- Return type:
Tensor
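The contract of forward can be illustrated with a small NumPy sketch of a Monte Carlo residual estimate. Everything here is an assumption for illustration: the helper name mc_residual_loss is hypothetical, NumPy arrays stand in for torch Tensors, and the residual is taken to be the squared Euclidean norm of the drift mismatch averaged over samples, which may differ from what svise actually computes.

```python
import numpy as np


def mc_residual_loss(t, z, approx_drift, prior_drift):
    """Hypothetical Monte Carlo estimate of an unweighted drift residual.

    t: (bs,) time stamps; z: (num_samples, bs, d) samples of the latent state.
    approx_drift, prior_drift: callables (t, z) -> (num_samples, bs, d) or (bs, d).
    Returns a (bs,) array: the squared residual norm averaged over samples.
    """
    r = approx_drift(t, z) - prior_drift(t, z)   # broadcasting handles a (bs, d) output
    sq = np.sum(r ** 2, axis=-1)                 # squared norm over the state dimension
    sq = np.broadcast_to(sq, z.shape[:-1])       # (num_samples, bs) even if r was (bs, d)
    return sq.mean(axis=0)                       # Monte Carlo average -> (bs,)


def zero_prior(t, z):
    """Toy prior returning a (bs, d) array, i.e. the second allowed output shape."""
    return np.zeros((t.shape[0], z.shape[-1]))


rng = np.random.default_rng(0)
t = np.linspace(0.0, 1.0, 5)
z = rng.standard_normal((64, 5, 2))
same = mc_residual_loss(t, z, lambda t, z: -z, lambda t, z: -z)  # identical drifts
diff = mc_residual_loss(t, z, lambda t, z: -z, zero_prior)       # mismatched drifts
```

Passing a prior that returns (bs, d) exercises the broadcasting case allowed for f in the parameter list.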
- abstract generate_samples(t: Tensor, num_samples: int, *args, **kwargs) → Tensor
Generates samples from the approximating SDE marginal distribution at time t and optionally returns some intermediate quantities.
- Parameters:
t (Tensor) – (bs,) time stamps at which to generate samples
num_samples (int) – number of independent samples at each time stamp
- Returns:
samples of latent states
- Return type:
Tensor
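Reparameterized sampling is the standard way to draw samples from a Gaussian marginal while keeping them differentiable with respect to the mean and covariance. A NumPy sketch under our own assumptions (hypothetical helper reparam_samples, mean of shape (bs, d), covariance of shape (bs, d, d)):

```python
import numpy as np


def reparam_samples(mean, K, num_samples, rng):
    """Hypothetical sampler: z = m(t) + L(t) eps with K(t) = L(t) L(t)^T.

    mean: (bs, d) marginal means; K: (bs, d, d) marginal covariances.
    Returns (num_samples, bs, d) samples.
    """
    L = np.linalg.cholesky(K)                               # batched Cholesky factors
    eps = rng.standard_normal((num_samples,) + mean.shape)  # standard normal noise
    return mean + np.einsum("bij,nbj->nbi", L, eps)         # apply L(t) per time stamp


rng = np.random.default_rng(0)
mean = np.zeros((3, 2))
K = np.tile(np.diag([1.0, 4.0]), (3, 1, 1))                 # same covariance at 3 times
z = reparam_samples(mean, K, 20000, rng)
```

In a torch implementation the same pattern, using torch.linalg.cholesky, keeps gradients flowing back to whatever parametrizes m(t) and K(t).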
Spectral parametrization
- class svise.sde_learning.SpectralMarginalSDE(d: int, t_span: Tuple[float, float] | Tuple[Tensor, Tensor], diffusion_prior: DiffusionPrior, model_form: str = 'GLM', vmap: bool = False, **kwargs)
A model for the marginal statistics of a Markov GP using the spectral parametrization described in the main text.
- Parameters:
d (int) – dimension of the state
t_span (Union[Tuple[float, float], Tuple[Tensor, Tensor]]) – time span of data
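diffusion_prior (DiffusionPrior) – the diffusion prior
model_form (str, optional) – model used for the marginal statistics, either 'GLM' (default) or 'FCNN'; the FCNN option is not fully tested
vmap (bool, optional) – DEPRECATED
**kwargs – additional keyword arguments passed on to the GLM or FCNN model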
- K(t: Tensor) → Tensor
Compute the marginal covariance matrix at time t.
- Parameters:
t (Tensor) – Time stamps at which to compute the covariance matrix
- Returns:
(len(t), d, d) batch of covariance matrices
- Return type:
Tensor
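A spectral (eigen) decomposition K(t) = U diag(λ(t)) U^T guarantees a symmetric positive-definite covariance as long as U is orthogonal and the eigenvalues are positive. The sketch below is only a generic illustration with a fixed U and exponentiated log-eigenvalues; the exact parametrization used by SpectralMarginalSDE is the one described in the main text, and spectral_covariance and log_eig are hypothetical names.

```python
import numpy as np


def spectral_covariance(t, U, log_eig):
    """Hypothetical K(t) = U diag(exp(s(t))) U^T for a batch of times.

    t: (bs,); U: (d, d) orthogonal; log_eig: callable t -> (bs, d) log-eigenvalues.
    Returns (bs, d, d) symmetric positive-definite matrices.
    """
    lam = np.exp(log_eig(t))                         # positive eigenvalues, (bs, d)
    return np.einsum("ij,bj,kj->bik", U, lam, U)     # U diag(lam) U^T for each time


def log_eig(t):
    """Toy time-dependent log-eigenvalues, shape (bs, 3)."""
    return np.stack([t, -t, np.zeros_like(t)], axis=-1)


rng = np.random.default_rng(1)
U, _ = np.linalg.qr(rng.standard_normal((3, 3)))     # random orthogonal matrix
t = np.linspace(0.0, 1.0, 4)
K = spectral_covariance(t, U, log_eig)
```

Keeping U orthogonal means the eigenvalues of each K(t) are exactly exp(s(t)), so positivity never has to be enforced separately.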
- drift(t: Tensor, z: Tensor) → Tensor
Compute the drift function of the equivalent SDE at times t.
- Parameters:
t (Tensor) – (bs,) time stamps at which to compute the drift
z (Tensor) – (…, bs, d) batch of latent states
- Returns:
(…, bs, d) batch of drift function evaluations
- Return type:
Tensor
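For a Gaussian marginal with mean m(t) and covariance K(t), one standard way to write an SDE that reproduces those marginals is to take a drift that is linear in z. Assuming Σ(t) denotes the symmetric diffusion matrix associated with the diffusion prior (our notation, not the source's), the drift and covariance must satisfy a Lyapunov equation, and one particular solution is shown on the last line. Whether drift uses exactly this solution is determined by the main text.

```latex
\begin{aligned}
f_q(t, z) &= \dot m(t) + A(t)\,\bigl(z - m(t)\bigr), \\
\dot K(t) &= A(t)\,K(t) + K(t)\,A(t)^\top + \Sigma(t), \\
A(t) &= \tfrac{1}{2}\bigl(\dot K(t) - \Sigma(t)\bigr)\,K(t)^{-1}.
\end{aligned}
```

Substituting the last line into the second gives ½(K̇ − Σ) + ½(K̇ − Σ) = K̇ − Σ, using the symmetry of K, K̇ and Σ, so the Lyapunov equation holds.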
- forward(t: Tensor, f: Callable[[Tensor, Tensor], Tensor], num_samples: int) → Tensor
Computes the (unweighted) residual loss between the approximating and prior SDEs.
- Parameters:
t (Tensor) – (bs,) time stamps at which to generate samples
f (Callable[[Tensor, Tensor], Tensor]) – prior SDE, called as f(t, z) and returning evaluations of shape (num_samples, bs, d) or (bs, d)
num_samples (int) – number of reparameterized samples at each time stamp
- Returns:
(bs,) approximate residual loss
- Return type:
Tensor
- generate_samples(t: Tensor, num_samples: int, return_intermediates: bool = False) → Tensor | Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]
Generates samples from the approximating SDE marginal distribution at time t and optionally returns some intermediate quantities.
- Parameters:
t (Tensor) – (bs,) time stamps at which to generate samples
num_samples (int) – number of independent samples at each time stamp
return_intermediates (bool, optional) – indicates whether to return intermediate quantities. Defaults to False.
- Returns:
samples of latent states or samples of latent states and intermediate quantities
- Return type:
Union[Tensor, Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]]
- mean(t: Tensor, return_grad=False) → Tensor
Returns the marginal mean at times t.
- Parameters:
t (Tensor) – (bs,) time stamps at which to evaluate mean
- Returns:
(bs, d) mean evaluated at times
- Return type:
Tensor
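With model_form='GLM', the mean is presumably linear in a set of fixed time features, which makes both m(t) and its time derivative cheap to evaluate; a drift computation typically needs that derivative. The Fourier basis, the helper names fourier_features and glm_mean, and the returned derivative are our assumptions; in particular, this page does not document what return_grad returns.

```python
import numpy as np


def fourier_features(t, t_span, n_freq):
    """Hypothetical basis [1, sin(2πkτ), cos(2πkτ)], τ = t rescaled to [0, 1]."""
    t0, t1 = t_span
    tau = (t - t0) / (t1 - t0)                      # (bs,)
    k = np.arange(1, n_freq + 1)                    # frequencies 1..n_freq
    arg = 2 * np.pi * tau[:, None] * k              # (bs, n_freq)
    scale = 2 * np.pi * k / (t1 - t0)               # d(arg)/dt by the chain rule
    ones = np.ones_like(tau)[:, None]
    phi = np.concatenate([ones, np.sin(arg), np.cos(arg)], axis=1)
    dphi = np.concatenate([0 * ones, scale * np.cos(arg), -scale * np.sin(arg)], axis=1)
    return phi, dphi                                # each (bs, 1 + 2 * n_freq)


def glm_mean(t, W, t_span, n_freq):
    """m(t) = Φ(t) W and dm/dt = Φ'(t) W; W has shape (1 + 2 * n_freq, d)."""
    phi, dphi = fourier_features(t, t_span, n_freq)
    return phi @ W, dphi @ W                        # each (bs, d)


rng = np.random.default_rng(2)
t_span = (0.0, 2.0)
n_freq = 3
W = rng.standard_normal((1 + 2 * n_freq, 2))        # (num_features, d) weights
t = np.linspace(0.0, 2.0, 7)
m, dm = glm_mean(t, W, t_span, n_freq)              # each (7, 2)
```

Because the model is linear in W, the weights are natural candidates for what mean_parameters yields, although that mapping is also an assumption here.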
- mean_parameters() → Iterator[Parameter]
Returns an iterator over the mean parameters of the marginal SDE.
- Returns:
An iterator over the mean parameters of the marginal SDE.
- Return type:
Iterator[Parameter]
- unweighted_residual_loss(drift: Tensor, f: Tensor) → Tensor
Computes the (unweighted) residual loss between the approximating and prior SDEs.
- Parameters:
drift (Tensor) – (n_samples, bs, d) drift function evaluations
f (Tensor) – (n_samples, bs, d) / (bs, d) prior SDE evaluations
- Returns:
residual loss
- Return type:
Tensor
Diagonal parametrization
- class svise.sde_learning.DiagonalMarginalSDE(d: int, t_span: Tuple[float, float] | Tuple[Tensor, Tensor], diffusion_prior: DiagonalDiffusionPrior, model_form: str = 'GLM', vmap: bool = False, **kwargs)
A model for the marginal statistics of a Markov GP whose covariance is diagonal.
- Parameters:
d (int) – dimension of the state
t_span (Union[Tuple[float, float], Tuple[Tensor, Tensor]]) – time span of data
diffusion_prior (DiagonalDiffusionPrior) – the diagonal diffusion prior
model_form (str, optional) – model used for the marginal statistics, either 'GLM' (default) or 'FCNN'; the FCNN option is not fully tested
vmap (bool, optional) – DEPRECATED
**kwargs – additional keyword arguments passed on to the GLM or FCNN model
- K(t: Tensor) → Tensor
Compute the marginal covariance matrix at time t.
- Parameters:
t (Tensor) – Time stamps at which to compute the covariance matrix
- Returns:
(len(t), d, d) batch of covariance matrices
- Return type:
Tensor
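Because the covariance is diagonal, only d variances per time stamp need to be modelled, and the dense (len(t), d, d) batch returned by K can be recovered by placing them on the diagonal. A NumPy sketch with a hypothetical helper diagonal_covariance (for Tensors, torch.diag_embed performs the same embedding):

```python
import numpy as np


def diagonal_covariance(var):
    """Hypothetical helper: embed (bs, d) variances into (bs, d, d) diagonal matrices."""
    bs, d = var.shape
    K = np.zeros((bs, d, d), dtype=var.dtype)
    idx = np.arange(d)
    K[:, idx, idx] = var            # write each row of var onto one diagonal
    return K


var = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])   # (bs=3, d=2) variances
K = diagonal_covariance(var)
```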
- drift(t: Tensor, z: Tensor) → Tensor
Compute the drift function of the equivalent SDE at times t.
- Parameters:
t (Tensor) – (bs,) time stamps at which to compute the drift
z (Tensor) – (…, bs, d) batch of latent states
- Returns:
(…, bs, d) batch of drift function evaluations
- Return type:
Tensor
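With a diagonal covariance k(t) and a diagonal diffusion σ²(t), a linear drift that preserves the Gaussian marginals decouples across state dimensions: each component satisfies k̇ = 2ak + σ², so a = (k̇ − σ²)/(2k). The sketch below assumes this linear form, which is our reading rather than the documented implementation; the helper name and argument layout are hypothetical. As a check, a stationary Ornstein-Uhlenbeck process dz = −θz dt + σ dW has k = σ²/(2θ) and k̇ = 0, which recovers a = −θ.

```python
import numpy as np


def diagonal_linear_drift(z, m, dm, k, dk, sigma2):
    """Hypothetical drift f(t, z) = dm + a * (z - m) with a = (dk - sigma2) / (2 k).

    m, dm, k, dk, sigma2: (bs, d) mean, mean derivative, variance, variance
    derivative and diagonal diffusion at each time; z: (..., bs, d).
    Returns (..., bs, d).
    """
    a = (dk - sigma2) / (2.0 * k)     # per-dimension solution of dk = 2 a k + sigma2
    return dm + a * (z - m)


# Stationary Ornstein-Uhlenbeck check: theta = 1.5, sigma^2 = 2.0, so k = 2 / 3.
theta, sigma2 = 1.5, 2.0
shape = (4, 1)                                          # (bs, d)
k = np.full(shape, sigma2 / (2 * theta))
zeros = np.zeros(shape)
z = np.random.default_rng(3).standard_normal((10, 4, 1))
f = diagonal_linear_drift(z, zeros, zeros, k, zeros, np.full(shape, sigma2))
```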
- forward(t: Tensor, f: Callable[[Tensor, Tensor], Tensor], num_samples: int) → Tensor
Computes the (unweighted) residual loss between the approximating and prior SDEs.
- Parameters:
t (Tensor) – (bs,) time stamps at which to generate samples
f (Callable[[Tensor, Tensor], Tensor]) – prior SDE, called as f(t, z) and returning evaluations of shape (num_samples, bs, d) or (bs, d)
num_samples (int) – number of reparameterized samples at each time stamp
- Returns:
(bs,) approximate residual loss
- Return type:
Tensor
- generate_samples(t: Tensor, num_samples: int) → Tensor
Generates samples from the approximating SDE marginal distribution at time t.
- Parameters:
t (Tensor) – (bs,) time stamps at which to generate samples
num_samples (int) – number of independent samples at each time stamp
- Returns:
samples of latent states
- Return type:
Tensor
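With a diagonal covariance, reparameterized sampling needs no Cholesky factorization: the standard deviation is the elementwise square root of the variances, which avoids the O(d³) factorization required in the dense case. A NumPy sketch with a hypothetical helper diagonal_samples, assuming means and variances are both given as (bs, d) arrays:

```python
import numpy as np


def diagonal_samples(mean, var, num_samples, rng):
    """Hypothetical sampler: z = m(t) + sqrt(k(t)) * eps, elementwise.

    mean, var: (bs, d) marginal means and variances. Returns (num_samples, bs, d).
    """
    eps = rng.standard_normal((num_samples,) + mean.shape)
    return mean + np.sqrt(var) * eps     # broadcasting over the sample axis


rng = np.random.default_rng(4)
mean = np.array([[0.0, 1.0], [2.0, 3.0]])
var = np.array([[1.0, 0.25], [4.0, 1.0]])
z = diagonal_samples(mean, var, 20000, rng)
```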
- mean(t: Tensor, return_grad=False) → Tensor
Returns the marginal mean at times t.
- Parameters:
t (Tensor) – (bs,) time stamps at which to evaluate mean
- Returns:
(bs, d) mean evaluated at times
- Return type:
Tensor
- mean_parameters() → Iterator[Parameter]
Returns an iterator over the mean parameters of the marginal SDE.
- Returns:
An iterator over the mean parameters of the marginal SDE.
- Return type:
Iterator[Parameter]
- unweighted_residual_loss(drift: Tensor, f: Tensor) → Tensor
Computes the (unweighted) residual loss between the approximating and prior SDEs.
- Parameters:
drift (Tensor) – (n_samples, bs, d) drift function evaluations
f (Tensor) – (n_samples, bs, d) / (bs, d) prior SDE evaluations
- Returns:
residual loss
- Return type:
Tensor