#
# File: smoother.py
#
import torch
from numpy.random import choice
from scipy.optimize import least_squares
from torch.distributions.multivariate_normal import MultivariateNormal
from ceem import utils
from ceem.opt_criteria import GroupSOSCriterion, STRStateCriterion
from tqdm import tqdm
def NLSsmoother(x0, criterion, system, solver_kwargs={'verbose': 2}):
    """
    Smoothing with a Gauss-Newton based approach

    Args:
        x0 (torch.tensor): (1,T,n) initial guess for the system states
        criterion (SOSCriterion): criterion to optimize
        system (DiscreteDynamicalSystem)
        solver_kwargs (dict): options for scipy.optimize.least_squares; an
            optional 'tr_rho' entry adds a soft trust-region penalty about x0
    Returns:
        x (torch.tensor): (1,T,n) smoothed system states
        metrics (dict): final cost, success flag, and wall-clock time
    """
    # copy so that popping 'tr_rho' below does not mutate the caller's dict
    # (or the shared mutable default argument)
    solver_kwargs = dict(solver_kwargs)
    if 'tr_rho' in solver_kwargs:
        tr_rho = solver_kwargs.pop('tr_rho')
        criterion = GroupSOSCriterion([criterion, STRStateCriterion(tr_rho, x0)])
    B, T, xdim = x0.shape
    assert B == 1, f"Smoothing one trajectory at a time. x0.shape[0] is {B} but should be 1."
    def loss(x):
        # residual vector for scipy.optimize.least_squares
        with torch.no_grad():
            x = torch.tensor(x).view(1, T, xdim).to(x0.dtype)
            resid = criterion.residuals(system, x)
        return resid.numpy()

    def jac(x):
        # sparse Jacobian of the residuals w.r.t. the flattened states
        x = torch.tensor(x).view(1, T, xdim).to(x0.dtype)
        return criterion.jac_resid_x(system, x, sparse=True)
    with utils.Timer() as time:
        kwargs = dict(method='trf', loss='linear')
        kwargs.update(solver_kwargs)
        opt_result = least_squares(loss, x0.view(-1).detach().numpy(), jac, **kwargs)
    x = torch.tensor(opt_result.x).view(1, T, xdim)
    metrics = {'fun': float(opt_result.cost), 'success': opt_result.success, 'time': time.dt}
    return x, metrics 
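
# Usage sketch (illustrative; `sys` and `crit` are placeholder names for a
# DiscreteDynamicalSystem and an SOSCriterion, not objects defined here):
#
#   x0 = torch.zeros(1, T, sys.xdim)
#   x_sm, metrics = NLSsmoother(x0, crit, sys,
#                               solver_kwargs={'verbose': 0, 'tr_rho': 0.1})
#   # metrics['fun'] is the final cost; metrics['time'] the wall-clock seconds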
def EKF(x0, y_T, u_T, sigma0, Q, R, system):
    """
    Extended Kalman filter

    Args:
        x0 (torch.tensor): (B, xdim) initial system states
        y_T (torch.tensor): (B, T, ydim) observations
        u_T (torch.tensor): (B, T, udim) controls
        sigma0 (torch.tensor): (xdim, xdim) initial state covariance
        Q (torch.tensor): (xdim, xdim) dynamics error covariance
        R (torch.tensor): (ydim, ydim) observation error covariance
        system (DiscreteDynamicalSystem)
    Returns:
        x_filt (torch.tensor): (B, T, xdim) filtered system states
        y_pred (torch.tensor): (B, T, ydim) predicted observations before state correction
    """
    xdim = Q.shape[0]
    B, T, ydim = y_T.shape
    I = torch.eye(xdim)
    x = torch.zeros(B, T, xdim)
    x[:, 0:1] = x0
    y = y_T.clone()
    with torch.no_grad():
        y[:, 0:1] = system.observe(0, x[:, :1], u_T[:, :1])
    St = torch.zeros(B, T, xdim, xdim)
    St[:, 0] = sigma0.unsqueeze(0)
    for t in tqdm(range(1, T)):
        # Propagate dynamics
        with torch.no_grad():
            x[:, t:t + 1] = system.step(t - 1, x[:, t - 1:t], u_T[:, t - 1:t])
        # linearize the dynamics about the previous state estimate
        Gt = system.jac_step_x(t - 1, x[:, t - 1:t], u_T[:, t - 1:t]).detach()
        St_hat = Gt @ St[:, t - 1:t] @ Gt.transpose(-1, -2) + Q
        # Estimate observation
        with torch.no_grad():
            y[:, t:t + 1] = system.observe(t, x[:, t:t + 1], u_T[:, t:t + 1])
        Ht = system.jac_obs_x(t, x[:, t:t + 1], u_T[:, t:t + 1]).detach()
        Zt = Ht @ St_hat @ Ht.transpose(-1, -2) + R
        # Estimate Kalman Gain and correct xt
        Kt = St_hat @ Ht.transpose(-1, -2) @ torch.inverse(Zt)
        x[:, t:t + 1] = x[:, t:t + 1] + (Kt @ (
            y_T[:, t:t + 1] - y[:, t:t + 1]).unsqueeze(-1)).squeeze(-1)
        St[:, t:t + 1] = (I - Kt @ Ht) @ St_hat
    return x, y 
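
# Usage sketch (illustrative; `sys`, `x0`, `y`, `u` are placeholders): filter a
# batch of trajectories under isotropic noise assumptions.
#
#   xdim, ydim = sys.xdim, y.shape[-1]
#   x_filt, y_pred = EKF(x0, y, u, sigma0=torch.eye(xdim),
#                        Q=1e-2 * torch.eye(xdim), R=1e-1 * torch.eye(ydim),
#                        system=sys)
#   innov = y - y_pred  # one-step-ahead prediction errors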
## Particle Smoother
class ParticleSmootherSystemWrapper:
    """Adapts a DiscreteDynamicalSystem for use with ParticleSmoother."""

    def __init__(self, sys, R):
        self._sys = sys
        # Gaussian observation-noise model used to score particles in obsll
        self._Rmvn = MultivariateNormal(torch.zeros(R.shape[0],), R)
    def __call__(self, x, t):
        """
        x (torch.tensor): (N,n) particles
        t (int): time
        """
        x = x.unsqueeze(1)
        t = torch.tensor([float(t)])
        nx = self._sys.step(t, x)
        return nx.squeeze(1)
    def obsll(self, x, y):
        """
        x (torch.tensor): (N,n) particles
        y (torch.tensor): (1,m) observation
        Returns the (N,1) log-likelihood of y under each particle.
        """
        y_ = self._sys.observe(None, x.unsqueeze(1)).squeeze(1)
        dy = y - y_
        logprob_y = self._Rmvn.log_prob(dy).unsqueeze(1)
        return logprob_y 
    @property
    def _xdim(self):
        return self._sys.xdim 
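
# Note: the wrapper above adapts the system's batched (B, T, n) interface to
# the (N, n) particle convention used by ParticleSmoother below, treating the
# N particles as a batch of length-1 trajectories and folding the observation
# noise R into a Gaussian log-likelihood for weighting.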
class ParticleSmoother:

    def __init__(self, N, system, obsll, Q, Px0, x0mean=None):
        """
        Args:
          N (int): number of particles
          system (callable): e.g. a ParticleSmootherSystemWrapper; maps
              ((N,xdim) particles, time t) to (N,xdim) next states and
              exposes an `_xdim` attribute
          obsll (callable): maps ((N,xdim) particles, (1,ydim) observation)
              to (N,1) observation log-likelihoods
          Q (torch.tensor): (xdim,xdim) dynamics noise covariance
          Px0 (torch.tensor): (xdim,xdim) initial state covariance
          x0mean (torch.tensor): (1,xdim) initial state mean (defaults to zeros)
        """
        self._N = N
        self._system = system
        self._obsll = obsll
        self._xdim = system._xdim
        self._Qchol = Q.cholesky().unsqueeze(0)
        self._Qpdf = MultivariateNormal(torch.zeros(self._xdim), Q)
        self._Px0chol = Px0.cholesky().unsqueeze(0)
        if x0mean is not None:
            self._x0mean = x0mean
        else:
            self._x0mean = torch.zeros(1, self._xdim)
        self._xfilt = None
        self._wfilt = None
        self._wsmooth = None
    def filter(self, y):
        """Run a bootstrap particle filter.

        Args:
          y (torch.tensor): (T, ydim) observations
        Returns:
          x (torch.tensor): (T, N, xdim) particles
          w (torch.tensor): (T, N, 1) filtered particle weights
        """
        T = y.shape[0]
        x = torch.zeros(T, self._N, self._xdim)
        w = torch.zeros(T, self._N, 1)
        # sample initial distribution
        x[0] = self._x0mean + (self._Px0chol @ torch.randn(self._N, self._xdim, 1)).squeeze(2)
        for t in range(T - 1):
            ## Observe
            log_wt = self._obsll(x[t], y[None,t])
            
            ## Update weights
            # numerically stable computation of w
            log_wt -= log_wt.max()
            wt = log_wt.exp()
            wt /= wt.sum()
            # since we divide by wt.sum(), subtracting off log_wt.max()
            # gives the same result
            w[t] = wt
            ## Resample
            rinds = choice(self._N, self._N, p=w[t, :, 0].detach().numpy())
            xtr = x[t, rinds]

            ## Propagate
            with torch.no_grad():
                x[t + 1] = self._system(
                    xtr, t) + (self._Qchol @ torch.randn(self._N, self._xdim, 1)).squeeze(2)
        # weight the final particles (no propagation after the last step)
        log_wt = self._obsll(x[-1], y[None, -1])
        log_wt -= log_wt.max()
        wt = log_wt.exp()
        wt /= wt.sum()
        w[-1] = wt
        return x, w 
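
    # Bootstrap (SIR) recap for filter(): at each step the particles are
    # weighted by the observation likelihood, w_t^i ~ exp(obsll(x_t^i, y_t)),
    # resampled from Categorical(w_t), then propagated through the dynamics
    # with additive N(0, Q) noise to form x_{t+1}.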
    def smoother(self, x, w):
        """Backward pass: reweight filtered particles into smoothed weights.

        Args:
          x (torch.tensor): (T, N, xdim) filtered particles
          w (torch.tensor): (T, N, 1) filtered weights
        Returns:
          w_N (torch.tensor): (T, N, 1) smoothed weights
        """
        T, N, n = x.shape
        ## Compute p(xt+1|xt)
        Tlogprobs = torch.zeros(T - 1, N, N)
        for t in range(T - 1):
            with torch.no_grad():
                xtp1_pred = self._system(x[t], t)
            xtp1_diff = xtp1_pred.unsqueeze(1) - x[None, t + 1]
            # Tlogprobs[t, i, j] = log p(x_{t+1}^j | x_t^i)
            Tlogprobs[t] = self._Qpdf.log_prob(xtp1_diff.reshape(-1, n)).reshape(N, N)

        # for numerical stability subtract the max over source particles
        Tlogprobs -= Tlogprobs.max(1)[0].unsqueeze(1)
        Tprobs = Tlogprobs.exp()
        # v[t, j] = sum_i w_t^i p(x_{t+1}^j | x_t^i)
        v = (w[:-1] * Tprobs).sum(1)
        # since Tprobs enters both the numerator and denominator below,
        # subtracting Tlogprobs.max above gives the same result as not
        # compute w_N by backward recursion
        w_N = w.clone()  # sets w_N[-1] = w[-1]
        for t in reversed(range(T - 1)):
            # w_N[t, i] = w[t, i] * sum_j w_N[t+1, j] * Tprobs[t, i, j] / v[t, j]
            w_N_t = w[t] * (w_N[t + 1].transpose(0, 1) * Tprobs[t] / v[t]).sum(1).unsqueeze(1)
            if w_N_t.sum() > 0.:
                w_N[t] = w_N_t
            # if no particle received weight, keep the filtered weight instead
            # normalize weights
            w_N[t] /= w_N[t].sum()
        return w_N
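
    # Backward recursion implemented in smoother() (forward-filter
    # backward-smoother), keeping particle positions fixed:
    #   v_t^j     = sum_l w_t^l p(x_{t+1}^j | x_t^l)
    #   w_{t|T}^i = w_t^i * sum_j w_{t+1|T}^j p(x_{t+1}^j | x_t^i) / v_t^j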
    def run(self, y):
        """Filter then smooth the (T, ydim) observations y."""
        x, w = self.filter(y)
        w_N = self.smoother(x, w)
        self._xfilt = x
        self._wfilt = w
        self._wsmooth = w_N
    def get_smooth_mean(self):
        """Return the (T, xdim) smoothed state mean; call run() first."""
        x_mean = (self._xfilt * self._wsmooth).sum(1) / self._wsmooth.sum(1)
        return x_mean
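
# End-to-end usage sketch (illustrative; `sys`, `y`, `Q`, `R`, `P0` are
# placeholders):
#
#   wrapped = ParticleSmootherSystemWrapper(sys, R)
#   ps = ParticleSmoother(N=500, system=wrapped, obsll=wrapped.obsll,
#                         Q=Q, Px0=P0)
#   ps.run(y)                      # y: (T, ydim) observations
#   x_sm = ps.get_smooth_mean()    # (T, xdim) smoothed state means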