Source code for svise.sdeint

import torch
from torch import Tensor
import torchsde
from .sde_learning import SDELearner
from typing import Callable


class _SDE(torch.nn.Module):
    # noise_type = "general"
    noise_type = "additive"
    sde_type = "ito"

    def __init__(self, drift: Callable, diffusion: Callable):
        super().__init__()
        self.drift = drift
        self.diffusion = diffusion

    # Drift
    def f(self, t, y):
        # print(t)
        # shape (batch_size, state_size)
        return self.drift(t, y)

    # Diffusion
    def g(self, t, y):
        # shape (bastch_size, state_size, state_size)
        return self.diffusion()


[docs]def solve_sde(sde: SDELearner, x0: Tensor, t_eval: Tensor, **kwargs) -> Tensor: """Wraps torchsde.sdeint to solve the approximate posterior over the drift + diffusion function in the sde Args: sde (SDELearner): some SDE x0 (Tensor): intial conditions t_eval (Tensor): batch of time stamps **kwargs: kwargs passed onto the torchsde solver Returns: Tensor: sde solution samples """ sde_tmp = _SDE(sde.drift, sde.diffusion) # perhaps we should check if |sde.diffusion| << 1 and then use a standard ode solver # note I have dug into the source code, this will work fine even for SDEs with random coefficients xs = torchsde.sdeint(sde_tmp, x0, t_eval, **kwargs) return xs