# Source code for svise.quadrature._unbiased_quadrature

import torch
from torch import Tensor
import torch.nn as nn
from ._quadrature import *
from ._barycentric_interp import BarycentricInterpolate
from typing import Callable, Union
from abc import ABC, abstractmethod
import warnings


def stratified_sample1d(
    t_0: Union[float, Tensor],
    t_1: Union[float, Tensor],
    M: int,
    L: int,
    dtype: torch.dtype,
    device: Union[str, torch.device] = "cpu",
) -> Tensor:
    """Stratified sampling in 1D.

    Splits ``[t_0, t_1]`` into ``M`` intervals (from ``torch.linspace``) and
    draws ``L`` independent uniform samples inside each interval.

    Args:
        t_0 (Union[float, Tensor]): lower bound
        t_1 (Union[float, Tensor]): upper bound
        M (int): number of intervals
        L (int): number of samples / interval
        dtype (torch.dtype): dtype of the returned samples
        device (Union[str, torch.device]): device the samples are created on

    Returns:
        Tensor: flat tensor of ``L * M`` samples. Viewed as ``(L, M)``, column
        ``k`` holds the ``L`` samples drawn from interval ``k``. Empty when
        ``M == 0``.
    """
    if M == 0:
        # linspace would return a single point, so there is no interval width
        # to read (t[1] would raise); zero strata means zero samples.
        return torch.empty(0, dtype=dtype, device=device)
    t = torch.linspace(t_0, t_1, M + 1, dtype=dtype, device=device)
    # interval width, taken from the first interval
    w = t[1] - t[0]
    # rand is in [0, 1): each column k is shifted to start at the left edge t[k]
    samples = torch.rand(L, M, dtype=dtype, device=device) * w + t[:-1]
    return samples.flatten()


class QuadRule1D(ABC, nn.Module):
    """Abstract base class for 1D quad rules.

    Concrete rules override ``__init__`` (calling this one to record the
    integration bounds and node count) and ``forward``, which turns an
    integrand into an estimate of its integral over ``[a, b]``.

    Args:
        a (Union[float, Tensor]): lower bound of integration
        b (Union[float, Tensor]): upper bound of integration
        N (int): number of quadrature nodes
    """

    @abstractmethod
    def __init__(
        self,
        a: Union[float, Tensor],
        b: Union[float, Tensor],
        N: int,
        *args,
        **kwargs,
    ) -> None:
        # extra positional / keyword arguments are accepted and ignored
        super().__init__()
        # problem definition, kept as plain (unregistered) attributes
        self.a, self.b, self.N = a, b, N
        # snapshot of torch's default dtype at construction time
        self.dtype = torch.get_default_dtype()

    @abstractmethod
    def forward(self, f: Callable) -> Tensor:
        """Estimate for integral

        Args:
            f (Callable): function to be integrated

        Returns:
            Tensor: Iq integral estimate
        """
class MonteCarloQuad(QuadRule1D):
    """Estimates integral with plain Monte Carlo sampling.

    Every call draws ``N`` fresh uniform samples on ``[a, b]`` and returns
    ``(b - a)`` times the mean of ``f`` over them.

    Args:
        a (Union[float, Tensor]): lower bound of integration
        b (Union[float, Tensor]): upper bound of integration
        N (int): number of Monte Carlo samples
    """

    def __init__(self, a: Union[float, Tensor], b: Union[float, Tensor], N: int):
        super().__init__(a, b, N)

    def forward(self, f: Callable):
        """Monte Carlo estimate of the integral of ``f`` over ``[a, b]``.

        Args:
            f (Callable): function to be integrated, evaluated on a 1D tensor
                of ``N`` sample points

        Returns:
            Tensor: ``(b - a) * mean(f(x))`` over the samples (the mean is
            taken over every element ``f`` returns)
        """
        # no device argument is passed, so samples use torch's default device
        xmc = torch.rand(self.N, dtype=self.dtype) * (self.b - self.a) + self.a
        return f(xmc).mean() * (self.b - self.a)


class StartifiedSampling(QuadRule1D):
    """Estimates integral with stratified sampling

    ``[a, b]`` is split into ``N`` intervals with one uniform sample drawn in
    each; the estimate is ``(b - a)`` times the mean of ``f`` over the samples.

    Args:
        a (Union[float, Tensor]): lower bound of integration
        b (Union[float, Tensor]): upper bound of integration
        N (int): number of quadrature nodes
    """

    def __init__(
        self, a: Union[float, Tensor], b: Union[float, Tensor], N: int
    ) -> None:
        super().__init__(a, b, N)

    def forward(self, f: Callable):
        """Stratified Monte Carlo estimate of the integral of ``f`` over ``[a, b]``.

        Args:
            f (Callable): function to be integrated

        Returns:
            Tensor: ``(b - a) * mean(f(x))`` over the stratified samples
        """
        # one sample per stratum, N strata
        xstrat = stratified_sample1d(self.a, self.b, self.N, 1, dtype=self.dtype)
        return f(xstrat).mean() * (self.b - self.a)


# Correctly spelled alias; the misspelled name above is kept for backward
# compatibility with existing callers.
StratifiedSampling = StartifiedSampling
class GaussLegendreQuad(QuadRule1D):
    """Estimates integral with Gauss Legendre quad.

    The quadrature weights and nodes are computed once at construction by
    ``gauss_legendre_vecs`` and registered as the buffers ``wq`` and ``xq``.

    Args:
        a (Union[float, Tensor]): lower bound of integration
        b (Union[float, Tensor]): upper bound of integration
        N (int): number of quadrature nodes
    """

    def __init__(self, a: Union[float, Tensor], b: Union[float, Tensor], N: int):
        super().__init__(a, b, N)
        weights, nodes = gauss_legendre_vecs(N, a, b, dtype=self.dtype)
        # registered as buffers (weights first, then nodes)
        for buf_name, buf in (("wq", weights), ("xq", nodes)):
            self.register_buffer(buf_name, buf)

    def forward(self, f: Callable):
        """Deterministic quadrature estimate of the integral of ``f``.

        Args:
            f (Callable): function to be integrated; its output's last axis
                must line up with the quadrature nodes

        Returns:
            Tensor: ``f(xq)`` contracted with the weights along its last axis,
            then averaged over any remaining entries
        """
        return torch.matmul(f(self.xq), self.wq).mean()
class UnbiasedGaussLegendreQuad(QuadRule1D):
    """Estimates integral using the unbiased quadrature scheme based on Gauss
    Legendre quadrature from L.-f. Lee, Interpolation, quadrature, and
    stochastic integration, Econometric Theory 17, (2001).

    The ``N`` evaluation points are split into ``nq`` Gauss-Legendre nodes and
    ``nmc`` stratified Monte Carlo samples. The estimate is

        Imc - gamma * (Iip - Iq)

    where ``Imc`` is the stratified Monte Carlo estimate of the integral of
    ``f``, ``Iq`` is the Gauss-Legendre estimate, and ``Iip`` is the Monte Carlo
    estimate (at the same samples) of the interpolant of ``f`` through the
    Gauss-Legendre nodes (``BarycentricInterpolate``).

    Args:
        a (Union[float, Tensor]): lower bound of integration
        b (Union[float, Tensor]): upper bound of integration
        N (int): total number of function evaluation points
        quad_percent (float): fraction in [0, 1] of the ``N`` points used as
            Gauss-Legendre nodes; ``nq = int(N * quad_percent)`` (truncated) and
            the remaining ``nmc = N - nq`` are stratified Monte Carlo samples
        gamma (Union[float, Tensor]): coefficient multiplying the correction
            ``(Iip - Iq)``; stored as the buffer ``gamma``
    """

    def __init__(
        self,
        a: Union[float, Tensor],
        b: Union[float, Tensor],
        N: int,
        quad_percent: float,
        gamma: Union[float, Tensor] = 1.0,
    ):
        super().__init__(a, b, N)
        err_msg = "Percent of quadrature points must be between 0 and 1"
        assert quad_percent <= 1.0, err_msg
        assert quad_percent >= 0.0, err_msg
        # split the evaluation budget between deterministic nodes and samples
        nq = int(N * quad_percent)
        nmc = N - nq
        self.nq = nq
        self.nmc = nmc
        wq, xq = gauss_legendre_vecs(nq, a, b, dtype=self.dtype)
        # Same result as torch.tensor(gamma) (a detached copy), without the
        # copy-construct warning torch.tensor emits when gamma is a Tensor.
        self.register_buffer("gamma", torch.as_tensor(gamma).detach().clone())
        self.register_buffer("wq", wq)
        self.register_buffer("xq", xq)

    def forward(self, f: Callable):
        """Estimate the integral of ``f`` over ``[a, b]``.

        Args:
            f (Callable): function to be integrated; its output's last axis
                must line up with the evaluation points

        Returns:
            Tensor: ``Imc - gamma * (Iip - Iq)``, or the plain Gauss-Legendre
            estimate ``Iq`` when there are no Monte Carlo samples (``nmc == 0``)
        """
        if self.nmc == 0:
            # quad_percent == 1: there is no Monte Carlo budget, so the
            # correction term cannot be formed (and zero strata cannot be
            # sampled). Fall back to the Gauss-Legendre estimate.
            return (f(self.xq) @ self.wq).mean()

        # Sample first, then evaluate f, keeping the original RNG call order
        # in case f itself draws random numbers.
        xmc = stratified_sample1d(
            self.a, self.b, self.nmc, 1, dtype=self.dtype, device=self.xq.device
        )
        dtype = xmc.dtype
        # evaluating function at quadrature nodes and at the stratified samples
        fxq = f(self.xq)
        fmc = f(xmc)
        Iq = (fxq @ self.wq).mean()
        Imc = fmc.mean() * (self.b - self.a)

        # interpolant of f through the quadrature nodes, evaluated at the samples
        interp = BarycentricInterpolate(self.xq, fxq)
        fip = interp(xmc)
        divisor = 1
        # On NaNs, retry with every other remaining node (lower degree). Stop
        # once a single node is left: further halving would keep reusing that
        # same node, so the old unbounded loop could spin forever (e.g. when f
        # itself returns NaN).
        while torch.isnan(fip).any() and divisor < self.nq:
            warnings.warn("Interp is nan, reducing interp degree.")
            divisor *= 2
            # nodes lie along the LAST axis of fxq (it is contracted with wq above)
            interp = BarycentricInterpolate(self.xq[::divisor], fxq[..., ::divisor])
            fip = interp(xmc).to(dtype)
        if torch.isnan(fip).any():
            warnings.warn(
                "Interp is still nan with a single node; the estimate will be nan."
            )
        Iip = fip.mean() * (self.b - self.a)
        # estimating integral
        return Imc - self.gamma * (Iip - Iq)