Source code for ceem.smoother

#
# File: smoother.py
#

import numpy as np
import torch
from numpy.random import choice
from scipy.optimize import least_squares, minimize
from torch.distributions.multivariate_normal import MultivariateNormal

from tqdm import tqdm

from ceem import utils
from ceem.opt_criteria import GroupSOSCriterion, STRStateCriterion


def NLSsmoother(x0, criterion, system, solver_kwargs={'verbose': 2}):
    """Smoothing with a Gauss-Newton based approach.

    Args:
        x0 (torch.tensor): (1,T,n) system states
        criterion (SOSCriterion): criterion to optimize
        system (DiscreteDynamicalSystem)
        solver_kwargs (dict): options for scipy.optimize.least_squares

    Returns:
        x (torch.tensor): (1,T,n) smoothed system states
        metrics (dict): final cost, solver success flag, and wall-clock time
    """
    # copy so that popping 'tr_rho' does not mutate the caller's dict
    # (or the shared default argument)
    solver_kwargs = dict(solver_kwargs)
    if 'tr_rho' in solver_kwargs:
        tr_rho = solver_kwargs.pop('tr_rho')
        criterion = GroupSOSCriterion([criterion, STRStateCriterion(tr_rho, x0)])

    B, T, xdim = x0.shape
    assert B == 1, f"Smoothing one trajectory at a time. x0.shape[0] is {B} but should be 1."

    def loss(x):
        with torch.no_grad():
            x = torch.tensor(x).view(1, T, xdim).to(x0.dtype)
            loss = criterion.residuals(system, x)
        return loss.numpy()

    def jac(x):
        x = torch.tensor(x).view(1, T, xdim)
        return criterion.jac_resid_x(system, x, sparse=True)

    with utils.Timer() as time:
        kwargs = dict(method='trf', loss='linear')
        kwargs.update(solver_kwargs)
        opt_result = least_squares(loss, x0.view(-1).detach().numpy(), jac, **kwargs)

    x = torch.tensor(opt_result.x).view(1, T, xdim)

    metrics = {'fun': float(opt_result.cost), 'success': opt_result.success, 'time': time.dt}

    return x, metrics
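# A minimal runnable sketch of NLSsmoother (illustrative only: `QuadraticCriterion`
# is a hypothetical stand-in exposing just the residuals/jac_resid_x interface used
# above, not one of ceem's SOSCriterion classes; with residuals r(x) = x - x_target
# the smoother simply recovers x_target).
import scipy.sparse


class QuadraticCriterion:

    def __init__(self, x_target):
        self._xt = x_target.view(-1)

    def residuals(self, system, x):
        return x.view(-1) - self._xt

    def jac_resid_x(self, system, x, sparse=True):
        # identity Jacobian of r(x) = x - x_target
        return scipy.sparse.eye(x.numel())


def _nls_demo():
    x_target = torch.randn(1, 20, 3)
    x0 = torch.zeros(1, 20, 3)
    x_opt, metrics = NLSsmoother(x0, QuadraticCriterion(x_target), system=None,
                                 solver_kwargs={'verbose': 0})
    assert torch.allclose(x_opt.float(), x_target, atol=1e-4)
    return metrics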
def EKF(x0, y_T, u_T, sigma0, Q, R, system):
    """Extended Kalman filter

    Args:
        x0 (torch.tensor): (B, xdim) initial system states
        y_T (torch.tensor): (B, T, ydim) observations
        u_T (torch.tensor): (B, T, udim) controls
        sigma0 (torch.tensor): (xdim, xdim) initial state covariance
        Q (torch.tensor): (xdim, xdim) dynamics error covariance
        R (torch.tensor): (ydim, ydim) observation error covariance
        system (DiscreteDynamicalSystem)

    Returns:
        x_filt (torch.tensor): (B, T, xdim) filtered system states
        y_pred (torch.tensor): (B, T, ydim) predicted observations before state correction
    """
    xdim = Q.shape[0]
    B, T, ydim = y_T.shape

    I = torch.eye(xdim)

    x = torch.zeros(B, T, xdim)
    x[:, 0:1] = x0
    y = y_T.clone()
    with torch.no_grad():
        y[:, 0:1] = system.observe(0, x[:, :1], u_T[:, :1])

    St = torch.zeros(B, T, xdim, xdim)
    St[:, 0] = sigma0.unsqueeze(0)

    for t in tqdm(range(1, T)):
        # Propagate dynamics
        with torch.no_grad():
            x[:, t:t + 1] = system.step(t - 1, x[:, t - 1:t], u_T[:, t - 1:t])
        Gt = system.jac_step_x(t, x[:, t:t + 1], u_T[:, t:t + 1]).detach()
        St_hat = Gt @ St[:, t - 1:t] @ Gt.transpose(-1, -2) + Q

        # Estimate observation
        with torch.no_grad():
            y[:, t:t + 1] = system.observe(t, x[:, t:t + 1], u_T[:, t:t + 1])
        Ht = system.jac_obs_x(t, x[:, t:t + 1], u_T[:, t:t + 1]).detach()
        Zt = Ht @ St_hat @ Ht.transpose(-1, -2) + R

        # Estimate Kalman gain and correct x_t
        Kt = St_hat @ Ht.transpose(-1, -2) @ torch.inverse(Zt)
        x[:, t:t + 1] = x[:, t:t + 1] + (Kt @ (
            y_T[:, t:t + 1] - y[:, t:t + 1]).unsqueeze(-1)).squeeze(-1)
        St[:, t:t + 1] = (I - Kt @ Ht) @ St_hat

    return x, y
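# A self-contained sketch of calling EKF on a toy linear system (illustrative
# only: `ToyLinearSystem` is a hypothetical stand-in exposing just the
# step/observe/Jacobian methods EKF uses, not ceem's DiscreteDynamicalSystem,
# and the observations below are random stand-ins rather than real data).
class ToyLinearSystem:
    """x_{t+1} = A x_t, y_t = C x_t, defined only for this example."""

    def __init__(self, A, C):
        self._A = A
        self._C = C

    def step(self, t, x, u):
        return x @ self._A.transpose(-1, -2)

    def observe(self, t, x, u):
        return x @ self._C.transpose(-1, -2)

    def jac_step_x(self, t, x, u):
        # (B,1,xdim,xdim) dynamics Jacobian, constant for a linear system
        return self._A.expand(x.shape[0], 1, *self._A.shape)

    def jac_obs_x(self, t, x, u):
        # (B,1,ydim,xdim) observation Jacobian
        return self._C.expand(x.shape[0], 1, *self._C.shape)


def _ekf_demo():
    A = torch.tensor([[1.0, 0.1], [0.0, 1.0]])  # constant-velocity dynamics
    C = torch.tensor([[1.0, 0.0]])  # observe position only
    T, xdim, ydim = 50, 2, 1
    y_T = torch.randn(1, T, ydim)  # stand-in observations
    u_T = torch.zeros(1, T, 1)  # controls (ignored by the toy system)
    x_filt, y_pred = EKF(torch.zeros(1, xdim), y_T, u_T,
                         sigma0=torch.eye(xdim),
                         Q=0.01 * torch.eye(xdim),
                         R=0.1 * torch.eye(ydim),
                         system=ToyLinearSystem(A, C))
    return x_filt, y_pred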
## Particle Smoother
class ParticleSmootherSystemWrapper:
    """Adapts a ceem system to the callable-dynamics plus observation
    log-likelihood interface expected by ParticleSmoother."""

    def __init__(self, sys, R):
        self._sys = sys
        self._Rmvn = MultivariateNormal(torch.zeros(R.shape[0]), R)

    def __call__(self, x, t):
        """
        Args:
            x (torch.tensor): (N,n) particles
            t (int): time
        """
        x = x.unsqueeze(1)
        t = torch.tensor([float(t)])
        nx = self._sys.step(t, x)
        return nx.squeeze(1)

    def obsll(self, x, y):
        """
        Args:
            x (torch.tensor): (N,n) particles
            y (torch.tensor): (1,m) observation

        Returns:
            logprob_y (torch.tensor): (N,1) observation log-likelihoods
        """
        y_ = self._sys.observe(None, x.unsqueeze(1)).squeeze(1)
        dy = y - y_
        logprob_y = self._Rmvn.log_prob(dy).unsqueeze(1)
        return logprob_y

    @property
    def _xdim(self):
        return self._sys.xdim
class ParticleSmoother:
    def __init__(self, N, system, obsll, Q, Px0, x0mean=None):
        """
        Args:
            N (int): number of particles
            system (callable): dynamics, mapping ((N,xdim) particles, int time) -> (N,xdim)
            obsll (callable): observation log-likelihood, mapping ((N,xdim), (1,ydim)) -> (N,1)
            Q (torch.tensor): (xdim,xdim) dynamics noise covariance
            Px0 (torch.tensor): (xdim,xdim) initial state covariance
            x0mean (torch.tensor): (1,xdim) initial state mean, defaults to zero
        """
        self._N = N
        self._system = system
        self._obsll = obsll
        self._xdim = system._xdim

        self._Qchol = Q.cholesky().unsqueeze(0)
        self._Qpdf = MultivariateNormal(torch.zeros(self._xdim), Q)
        self._Px0chol = Px0.cholesky().unsqueeze(0)

        if x0mean is not None:
            self._x0mean = x0mean
        else:
            self._x0mean = torch.zeros(1, self._xdim)

        self._xfilt = None
        self._wfilt = None
        self._wsmooth = None
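    # Note on the Cholesky factors above: if Q = L Lᵀ and z ~ N(0, I), then
    # L z ~ N(0, Q), since Cov(L z) = L Cov(z) Lᵀ = L Lᵀ = Q. This is why
    # `filter` below draws process noise as `self._Qchol @ torch.randn(...)`
    # and initial particles via `self._Px0chol` in the same way.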
    def filter(self, y):
        """
        Args:
            y (torch.tensor): (T, ydim) observations

        Returns:
            x (torch.tensor): (T, N, xdim) particles
            w (torch.tensor): (T, N, 1) filtered particle weights
        """
        T = y.shape[0]

        x = torch.zeros(T, self._N, self._xdim)
        w = torch.zeros(T, self._N, 1)

        # sample the initial distribution
        x[0] = self._x0mean + (self._Px0chol @ torch.randn(self._N, self._xdim, 1)).squeeze(2)

        for t in range(T - 1):
            ## Observe
            log_wt = self._obsll(x[t], y[None, t])

            ## Update weights
            # numerically stable computation of w: since we divide by wt.sum(),
            # subtracting off log_wt.max() gives the same result
            log_wt -= log_wt.max()
            wt = log_wt.exp()
            wt /= wt.sum()
            w[t] = wt

            ## Resample
            rinds = choice(self._N, self._N, p=w[t, :, 0].detach().numpy())
            xtr = x[t, rinds]

            ## Propagate
            with torch.no_grad():
                x[t + 1] = self._system(xtr, t) + (
                    self._Qchol @ torch.randn(self._N, self._xdim, 1)).squeeze(2)

        # weight the particles at the final timestep
        log_wt = self._obsll(x[-1], y[None, -1])
        log_wt -= log_wt.max()
        wt = log_wt.exp()
        wt /= wt.sum()
        w[-1] = wt

        return x, w
    def smoother(self, x, w):
        """
        Args:
            x (torch.tensor): (T, N, xdim) particles from `filter`
            w (torch.tensor): (T, N, 1) filtered weights from `filter`

        Returns:
            w_N (torch.tensor): (T, N, 1) smoothed particle weights
        """
        T, N, n = x.shape

        ## Compute p(x_{t+1} | x_t) between all particle pairs
        Tlogprobs = torch.zeros(T - 1, N, N)
        for t in range(T - 1):
            with torch.no_grad():
                xtp1_pred = self._system(x[t], t)
                xtp1_diff = xtp1_pred.unsqueeze(1) - x[None, t + 1]
                Tlogprobs[t] = self._Qpdf.log_prob(xtp1_diff.reshape(-1, n)).reshape(N, N)

        # for numerical stability subtract the max
        Tlogprobs -= Tlogprobs.max(1)[0].unsqueeze(1)
        Tprobs = Tlogprobs.exp()

        # compute v; since Tprobs also appears in the denominator below,
        # subtracting Tlogprobs.max above gives the same result as not
        v = (w[:-1] * Tprobs).sum(1)

        # compute w_N by backward recursion
        w_N = w.clone()  # sets w_N[-1] = w[-1]
        for t in range(T - 1):
            t = T - t - 2
            w_N_t = w[t] * (w_N[t + 1] * Tprobs[t] / v[t].unsqueeze(1)).sum(1).unsqueeze(1)
            # if no particles have weight, keep the filtered weight instead
            if w_N_t.sum() > 0.:
                w_N[t] = w_N_t

            # normalize weights
            w_N[t] /= w_N[t].sum()

        return w_N
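    # The backward pass in `smoother` is the standard forward-filter
    # backward-smoother (FFBS) weight recursion: writing the filtered weights
    # as w_t^i and the smoothed weights as w_{t|T}^i, it computes
    #
    #     v_t^j     = sum_k w_t^k p(x_{t+1}^j | x_t^k)
    #     w_{t|T}^i = w_t^i * sum_j w_{t+1|T}^j p(x_{t+1}^j | x_t^i) / v_t^j
    #
    # initialized with w_{T|T} = w_T; Tprobs[t][i, j] holds the (rescaled)
    # transition density p(x_{t+1}^j | x_t^i), and the rescaling cancels
    # because it appears in both numerator and denominator.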
    def run(self, y):
        """Filter then smooth the observations y, caching the results."""
        x, w = self.filter(y)
        w_N = self.smoother(x, w)
        self._xfilt = x
        self._wfilt = w
        self._wsmooth = w_N
    def get_smooth_mean(self):
        """Return the (T, xdim) smoothed state means; call after `run`."""
        x_mean = (self._xfilt * self._wsmooth).sum(1) / self._wsmooth.sum(1)
        return x_mean
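# An end-to-end sketch of the particle smoother on a toy scalar system
# (illustrative only: `ToyDrift` is a hypothetical stand-in exposing just the
# step/observe/xdim interface that ParticleSmootherSystemWrapper relies on,
# and the observations are random stand-ins rather than real data).
class ToyDrift:
    xdim = 1

    def step(self, t, x, u=None):
        return 0.9 * x  # stable scalar dynamics

    def observe(self, t, x, u=None):
        return x  # identity observation


def _particle_smoother_demo():
    torch.manual_seed(0)
    R = 0.1 * torch.eye(1)
    wrapped = ParticleSmootherSystemWrapper(ToyDrift(), R)
    ps = ParticleSmoother(N=200, system=wrapped, obsll=wrapped.obsll,
                          Q=0.05 * torch.eye(1), Px0=torch.eye(1))
    y = torch.randn(30, 1)  # (T, ydim) stand-in observations
    ps.run(y)
    return ps.get_smooth_mean()  # (T, xdim) smoothed state means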