import torch
from ceem import utils
from torch.distributions.categorical import Categorical
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from scipy.optimize import minimize
import timeit
from ceem import logger


class Resampler:
    def __init__(self, method):
        '''
        Class containing Multinomial, Systematic, Stratified resamplers
        Args:
            method (str): 'multinomial', 'stratified', or 'systematic'
        '''
        if method not in ['multinomial', 'stratified', 'systematic']:
            raise NotImplementedError(
                "method must be 'multinomial', 'stratified', or 'systematic'")
        self._method = method
    def __call__(self, weights):
        '''
        Resampling method
        Args:
            weights (torch.tensor): (B, N) particle weights
        Returns:
            indices (torch.tensor): (B, N) particle indices
        '''
        return getattr(self, self._method)(weights)
    def dist(self, weights):
        return Categorical(weights) 
    def icdf(self, weights, cds):
        '''
        Linear-time inverse CDF for a categorical distribution
        Args:
            weights (torch.tensor): (B, N) particle weights
            cds (torch.tensor): (B,N) ascending ordered cumulative densities
        Returns:
            indices (torch.tensor): (B, N) particle indices
        '''
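        # Worked example (illustrative): with weights [[0.2, 0.5, 0.3]] the running
        # cumulative sum is 0.2 -> 0.7 -> 1.0, so the ordered densities [[0.1, 0.4, 0.9]]
        # map to particle indices [[0, 1, 2]]:
        #
        #   >>> Resampler('systematic').icdf(torch.tensor([[0.2, 0.5, 0.3]]),
        #   ...                              torch.tensor([[0.1, 0.4, 0.9]]))
        #   tensor([[0, 1, 2]])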
        B,N = weights.shape
        indices = torch.zeros(B,N, dtype=torch.long)
        for b in range(B):
            i = 0
            cumden = weights[b,i]
            prev_cds = -1.0
            for n in range(N):
                assert prev_cds < cds[b,n], 'cds must be ascending ordered.'
                prev_cds = cds[b,n]
                while cds[b,n] > cumden:
                    i += 1
                    cumden += weights[b,i]
                indices[b,n] = i
        return indices 
    def multinomial(self, weights):
        N = weights.shape[1]
        dist = self.dist(weights)
        return dist.sample((N,)).T 
    def stratified(self, weights):
        B,N = weights.shape
        cds = (torch.rand(B,N) + torch.arange(N, dtype=torch.get_default_dtype()).unsqueeze(0))*(1./N)
        return self.icdf(weights, cds) 
    def systematic(self, weights):
        B,N = weights.shape
        cds = (torch.rand(B,1) + torch.arange(N, dtype=torch.get_default_dtype()).unsqueeze(0))*(1./N)
        return self.icdf(weights, cds)  
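

# Minimal usage sketch for Resampler (illustrative; the helper name `_resampler_example`
# is not part of the original module): all three schemes return (B, N) index tensors
# that can be used to gather the resampled particles.
def _resampler_example():
    B, N, n = 2, 8, 3
    particles = torch.randn(B, N, n)
    weights = torch.softmax(torch.randn(B, N), dim=1)           # normalized (B, N) weights
    for method in ('multinomial', 'stratified', 'systematic'):
        idx = Resampler(method)(weights)                        # (B, N) particle indices
        resampled = particles[torch.arange(B).unsqueeze(1), idx]
        assert resampled.shape == particles.shape
    return resampled
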
class faPF:
    def __init__(self, N, system, Q, R, Px0, 
        x0mean=None, resampler=Resampler('systematic'), 
        FFBSi_N = None,
        burnin_T = None,
        ess_frac = 0.5):
        """
        Fully-Adapted Particle Filter
        Args:
          N (int): number of particles
          system (DiscreteDynamics): system model providing step, observe and jac_obs_x
          Q (torch.tensor): (xdim,xdim) process-noise covariance
          R (torch.tensor): (ydim,ydim) observation-noise covariance
          Px0 (torch.tensor): (xdim,xdim) initial-state covariance
          x0mean (torch.tensor): (1,xdim) initial-state mean (defaults to zeros)
          resampler (Resampler): resampler
          FFBSi_N (int): number of backward trajectories to sample
          burnin_T (int or None): time steps skipped when computing the MCEM Q
          ess_frac (float): Effective Sample Size fraction below which resampling is triggered
        """
        self._N = N
        self._system = system
        self._xdim = system._xdim
        self._ydim = system._ydim
        self._Q = Q.unsqueeze(0)
        self._Qinv = Q.inverse().unsqueeze(0)
        self._Qdist = MultivariateNormal(torch.zeros(self._xdim), Q)
        self._x0mean = x0mean if x0mean is not None else torch.zeros(1, self._xdim)
        self._Px0 = Px0.unsqueeze(0)
        self._Px0dist = MultivariateNormal(self._x0mean, Px0)
        self._R = R.unsqueeze(0)
        self._Rinv = R.inverse().unsqueeze(0)
        self._Rdist = MultivariateNormal(torch.zeros(self._ydim), R)
        self._resampler = resampler
        self._FFBSi_N = FFBSi_N if FFBSi_N else max(N//10,1)
        self._burnin_T = burnin_T
        self._ess_frac = ess_frac
    def filter(self, y):
        '''
        Run filter
        Args:
            y (torch.tensor): (B,T,m) observations
        Returns:
            x (torch.tensor): (B,N,T,n) filtered states
            xr (torch.tensor): (B,N,T,n) resampled (ancestor) states
            w (torch.tensor): (B,N,T) filtered weights
            mean_ll (torch.tensor): scalar average per-step log-likelihood estimate
        '''
        N = self._N
        B,T,m = y.shape
        x = torch.zeros(B, N, T,  self._xdim)
        xr = torch.zeros_like(x)
        w = torch.ones(B, N, T, dtype=torch.get_default_dtype())/N # (1/N for faPF)  
        # sample initial distribution from p(x0 | y0)
        x[:,:,0:1] = self.sample_initial(y[:,0:1], N)
        v_unnorm = torch.zeros(B,N,T-1)
        log_prev_v = torch.ones(B,N, dtype=torch.get_default_dtype())
        for t in range(1,T):
            # sample x_t from p(x_t | x_t-1, y_t)
            xtm1 = x[:,:,t-1:t]
            yt = y[:,t:t+1]
            # resample
            resampdist = self.get_resampling_distribution(t-1, xtm1)
            log_vt = resampdist.log_prob(yt.repeat(1,N,1).view(B*N,m)).view(B,N)
           
            v_unnorm[:,:,t-1] = log_vt.exp()
            
            log_vt += log_prev_v
            log_vt -= log_vt.max(dim=1)[0].unsqueeze(1)
            vt = log_vt.exp()
            vt /= vt.sum(dim=1).unsqueeze(1)
            inds = self._resampler(vt.clone())
            xtm1r = xtm1.clone()
            for b in range(B):
                # adaptive
                ESS = 1./(vt[b]**2).sum()
                if ESS < self._ess_frac * N:
                    xtm1r[b] = xtm1[b, inds[b]]
                    log_prev_v[b] = torch.ones(N, dtype=torch.get_default_dtype())
                    # print('Resample at t=%d, ESS=%.3f'%(t,ESS))
                else:
                    log_prev_v[b] = log_vt[b]
            xr[:,:,t-1:t] = xtm1r
            # propagate
            sampdist = self.get_sampling_distribution(t-1, xtm1r, yt)
            xt = sampdist.sample().reshape(B,N,1,self._xdim)
            x[:,:,t:t+1] = xt
        mean_ll = v_unnorm.mean(1).log().mean()
        return x, xr, w, mean_ll 
    def sample_initial(self, y0, N):
        '''
        Sample from p(x0 | y0) 
        Args:
            y0 (torch.tensor): (B,1,m) initial observation
            N (int): number of particles to draw
        Returns:
            x0 (torch.tensor): (B,N,1,n) samples of x0
        Notes: 
            Eqns 352-354 https://www.math.uwaterloo.ca/~hwolkowi/matrixcookbook.pdf
        '''
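        # Gaussian conditioning used below (Matrix Cookbook Eqns 352-354), with the
        # observation model linearised at the prior mean, y0 = C x0 + v, v ~ N(0, R):
        #   Sig_y    = R + C Px0 C^T
        #   muhat_x  = mu_x + Px0 C^T Sig_y^{-1} (y0 - C mu_x)
        #   Sighat_x = Px0 - Px0 C^T Sig_y^{-1} C Px0
        # so that x0 | y0 ~ N(muhat_x, Sighat_x).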
        B, _, m = y0.shape
        y0 = y0.squeeze(1)
        mu_x = self._x0mean.repeat(B,1)
        tinp = torch.tensor([0] * B).unsqueeze(1).to(dtype=torch.get_default_dtype())
        C = self._system.jac_obs_x(tinp, mu_x.unsqueeze(1)).detach().squeeze(1)
        Sig_c = (C @ self._Px0).transpose(-2,-1)
        Sig_y = self._R + C @ self._Px0 @ C.transpose(-2,-1)
        muhat_x = mu_x + (Sig_c @ torch.solve(
            y0.unsqueeze(-1) - C @ mu_x.unsqueeze(-1), Sig_y)[0]).squeeze(-1)
        Sighat_x = self._Px0 - Sig_c @ torch.solve(Sig_c.transpose(-2,-1), Sig_y)[0]
        x0dist = MultivariateNormal(muhat_x, Sighat_x)
        # Note: log p(x0|y0) under x0dist differs from log p(y0|x0) + log p(x0) only by a
        # constant in x0, which was used as a sanity check for the conditioning above.
        return x0dist.sample((N,)).transpose(0,1).unsqueeze(2) 
    def get_sampling_distribution(self, t, xtm1, yt):
        '''
        Sampling distribution p(xt | xt-1, yt)
        Args:
            t (int): time-index
            xtm1 (torch.tensor): (B,N,1,n) state-particles
            yt (torch.tensor): (B,1,m) observation 
        Returns:
            dist (MultivariateNormal): p(xt | xt-1, yt)
        Notes:
            Implementation of
            https://link.springer.com/content/pdf/10.1023%2FA%3A1008935410038.pdf
            Equations 18-20
        '''
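        # In the notation above, with f the state map, g the observation map and H_t the
        # observation Jacobian evaluated at f(x_{t-1}), Eqns 18-20 give
        #   Sigma_t = (Q^{-1} + H_t^T R^{-1} H_t)^{-1}
        #   m_t     = Sigma_t (Q^{-1} f(x_{t-1}) + H_t^T R^{-1} (y_t - g(f(x_{t-1})) + H_t f(x_{t-1})))
        # i.e. a single Kalman measurement update of the prior N(f(x_{t-1}), Q).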
        B,N,_,xdim = xtm1.shape
        m = yt.shape[-1]
        tinp = torch.tensor([t] * B).unsqueeze(1).to(dtype=torch.get_default_dtype())
        fxtm1 = self._system.step(tinp-1, xtm1.view(B*N,1,xdim))
        Ht = self._system.jac_obs_x(t, fxtm1).detach().squeeze(1)
        HtT = Ht.transpose(-2,-1)
        yt_exp = self._system.observe(tinp, fxtm1)
        fxtm1 = fxtm1.squeeze(1)
        Sigtinv = self._Qinv + HtT @ self._Rinv @ Ht
        Sig = Sigtinv.inverse()
        yt_ = yt.repeat(1,N,1).view(B*N,m,1) 
        yt_ += Ht @ fxtm1.unsqueeze(-1) - yt_exp.transpose(-2,-1)
        mt = Sig @ (self._Qinv @ fxtm1.unsqueeze(-1) + HtT @ self._Rinv @ yt_)
        mt = mt.squeeze(-1)
        return MultivariateNormal(mt, Sig) 
    def get_resampling_distribution(self, t, xtm1):
        '''
        Resampling distribution p(yt | xt-1)
        Args:
            t (int): time-index
            xtm1 (torch.tensor): (B,N,1,n) state-particles
        Returns:
            dist (MultivariateNormal): p(yt | xt-1)
        Notes:
            Implementation of
            https://link.springer.com/content/pdf/10.1023%2FA%3A1008935410038.pdf
            Equation 14
        '''
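        # Eqn 14 approximates the predictive observation density by
        #   p(y_t | x_{t-1}) ~= N(y_t; g(f(x_{t-1})), R + H_t Q H_t^T)
        # with H_t the observation Jacobian evaluated at f(x_{t-1}).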
        ydim = self._system.ydim
        B,N,_,xdim = xtm1.shape
        tinp = torch.tensor([t] * B).unsqueeze(1).to(dtype=torch.get_default_dtype())
        fxtm1 = self._system.step(tinp-1, xtm1.view(B*N,1,xdim))
        Ht = self._system.jac_obs_x(t, fxtm1).detach().squeeze(1)
        HtT = Ht.transpose(-2,-1)
        yt = self._system.observe(tinp, fxtm1)
        mt = yt.squeeze(1)
        Sig = (self._R + Ht @ self._Q @ HtT)
        return MultivariateNormal(mt, Sig) 
    '''
    Smoothing Methods
    '''
    def FFBSm(self, y, x, w, return_wij_tN=False):
        '''
        Compute the smoothing marginal distributions
        Args:
            y (torch.tensor): (B,T,m) observations
            x (torch.tensor): (B,N,T,n) filtered states
            w (torch.tensor): (B,N,T) filtered weights
            return_wij_tN (bool): return pair-wise smoothing weights
        Returns:
            xsm (torch.tensor): (B,N,T,n) smoothed states
            wsm (torch.tensor): (B,N,T) smoothed weights
            wij_tN (torch.tensor): (B,N,N,T-1) pair-wise smoothing weights
        Notes:
            Implementation of Forward Filtering-Backward Smoothing
            https://www.stats.ox.ac.uk/~doucet/doucet_johansen_tutorialPF2011.pdf
            Equation 49
        '''
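        # Backward recursion implemented below (Eqn 49 of the reference above):
        #   w_{tau|T}^i = w_tau^i * sum_j [ w_{tau+1|T}^j f(x_{tau+1}^j | x_tau^i)
        #                                   / sum_l w_tau^l f(x_{tau+1}^j | x_tau^l) ]
        # starting from w_{T|T} = w_T; the filtered particle locations are left unchanged.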
        B,N,T,n = x.shape
        wsm = w.clone()
        wsm /= wsm.sum(1).unsqueeze(1) # normalize
        wij_tN = torch.zeros(B,N,N,T-1)
        for t in range(T-1):
            tau = T - t - 2
            # compute p(X_tau+1 | X_tau)
            tinp = torch.tensor([tau] * B).unsqueeze(1).to(dtype=torch.get_default_dtype())
            fxtau = self._system.step(tinp, x[:,:,tau])
            dxtaup1 = x[:,:,tau+1].unsqueeze(2) - fxtau.unsqueeze(1) # (B,N,N,n)
            fprobs = self._Qdist.log_prob(dxtaup1.view(B*N*N,n)).view(B,N,N).exp()
            den = (fprobs * w[:,:,tau:tau+1]).sum(-2).unsqueeze(-2)
            
            wij_tN[:,:,:,tau] = wsm[:,:,tau].unsqueeze(-1) * (wsm[:,:,tau+1:tau+2].transpose(-2,-1) * fprobs / den)
            
            wsm[:,:,tau] = wij_tN[:,:,:,tau].sum(-1)
            wsm[:,:,tau] /= wsm[:,:,tau].sum(1).unsqueeze(1) # normalize
            wij_tN[:,:,:,tau] /= wij_tN[:,:,:,tau].sum(-1).sum(-1).unsqueeze(-1).unsqueeze(-1)
        xsm = x.clone()
        if return_wij_tN:
            return xsm, wsm, wij_tN
        else:
            return xsm, wsm 
    def FFBSi(self, x, w):
        '''
        Sample trajectories from the joint smoothing distribution (JSD) using FFBSi
        Args:
            x (torch.tensor): (B,N,T,n) filtered states
            w (torch.tensor): (B,N,T) filtered weights
        Returns:
            xsamps (torch.tensor): (B, Ns, T, n) sampled state trajs
        Notes: 
            See http://users.isy.liu.se/rt/schon/Publications/LindstenS2013.pdf
            Algorithm 4
        '''
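        # Backward simulation (Algorithm 4 of the reference above): with x_{t+1} already
        # drawn, the ancestor index at time t is sampled with probability
        #   P(index = i)  proportional to  w_t^i * f(x_{t+1} | x_t^i)
        # which is what Categorical(logits=wlgfprobs) realizes below.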
        B,N,T,n = x.shape
        Ns = self._FFBSi_N
        xsamps = torch.zeros((B, Ns, T, n))
        # Sample at T
        bj = Categorical(w[:,:,-1]).sample((Ns,)).T
        for b in range(B):
            xsamps[b,:,-1] = x[b, bj[b],-1]
        for tau in range(T-1):
            t = T - tau - 2
            tinp = torch.tensor([t] * B).unsqueeze(1).to(dtype=torch.get_default_dtype())
            fxt = self._system.step(tinp, x[:,:,t])
            dxtp1 = xsamps[:,:,t+1].unsqueeze(2) - fxt.unsqueeze(1) # (B,Ns,N,n)
            flgprobs = self._Qdist.log_prob(dxtp1.view(B*Ns*N,n)).view(B,Ns,N)
            wlgfprobs = flgprobs + w[:,:,t].unsqueeze(1).log()
            bj = Categorical(logits=wlgfprobs).sample()
            for b in range(B):
                xsamps[b,:,t] = x[b, bj[b],t]
        return xsamps 
    '''
    Methods for computing Q(theta | theta_k)
    '''
    def Q_MCEM(self, y, x):
        '''
        Compute the Monte-Carlo approximation of Q(theta, theta_k)
                using iid x_1:T ~ p(x_1:T | y_1:T)
        Args:
            y (torch.tensor): (B,T,m) observations
            x (torch.tensor): (B,N,T,n) sampled state trajs
        Returns:
            Q (torch.tensor): (1,) Q(theta, theta_k)
        '''
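        # With x^i_{1:T} ~ p(x_{1:T} | y_{1:T}) i.i.d., the Monte-Carlo approximation is
        #   Q(theta, theta_k) ~= (1/N) sum_i [ log p(x_0^i)                      (I1)
        #                                      + sum_t log f(x_{t+1}^i | x_t^i)  (I2)
        #                                      + sum_t log g(y_t | x_t^i)        (I3) ]
        # with the time sums starting after the burn-in period.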
        B,N,T,n = x.shape
        m = self._system.ydim
        burnin_T = self._burnin_T if self._burnin_T is not None else T//10
        Ttil = T - burnin_T
        x = x.detach()
        y = y.detach()
        # Compute I1
        if burnin_T == 0:
            I1 = self._Px0dist.log_prob(x[:,:,0].view(B*N,n)).view(B,N)
            I1 = I1.sum()
        else:
            I1 = 0.
        # Compute approx I2
        tinp = torch.stack([torch.arange(burnin_T, T-1,dtype=torch.get_default_dtype())]*B, dim=0)
        fx = self._system.step(tinp, x[:,:,burnin_T:-1].reshape(B, N*(Ttil-1),n)).view(B*N*(Ttil-1),n)
        dx = x[:,:,burnin_T+1:].reshape(B*N*(Ttil-1),n) - fx
        I2 = self._Qdist.log_prob(dx).sum()
        # Compute I3
        tinp = torch.stack([torch.arange(burnin_T, T,dtype=torch.get_default_dtype())]*B, dim=0)
        tinp = tinp.unsqueeze(1).repeat((1,N,1)).view(B,Ttil*N)
        gx = self._system.observe(tinp, x[:,:,burnin_T:].reshape(B,N*Ttil,n)).view(B,N,Ttil,m)
        dy = y[:,burnin_T:].unsqueeze(1) - gx
        glogprobs = self._Rdist.log_prob(dy.reshape(B*N*Ttil,m)).view(B,N,Ttil)
        I3 = glogprobs.sum()
        return (I1 + I2 + I3) / N 
    def Q_JSD_marginals(self, y, x, w, wij_tN):
        '''
        Compute Q(theta, theta_k) using approximate JSD marginals
        Args:
            y (torch.tensor): (B,T,m) observations
            x (torch.tensor): (B,N,T,n) filtered states
            w (torch.tensor): (B,N,T) filtered weights
            wij_tN (torch.tensor): (B,N,N,T-1) pair-wise smoothing weights
        Returns:
            Q (torch.tensor): (1,) Q(theta, theta_k)
        Notes:
            See http://user.it.uu.se/~thosc112/pubpdf/schonwn2011-2.pdf
            Equations 48-49
        '''
        if torch.any(torch.isnan(wij_tN)):
            raise ValueError('NaN encountered in pair-wise smoothing weights wij_tN')
        B,N,T,n = x.shape
        m = self._system.ydim
        x = x.detach()
        y = y.detach()
        w = w.detach()
        wij_tN = wij_tN.detach()
        # Compute I1
        I1 = self._Px0dist.log_prob(x[:,:,0].view(B*N,n)).view(B,N)
        I1 = (I1 * w[:,:,0]).sum()
        # Compute I2
        tinp = torch.stack([torch.arange(T-1,dtype=torch.get_default_dtype())]*B, dim=0)
        tinp = tinp.unsqueeze(1).repeat((1,N,1)).view(B,(T-1)*N)
        fx = self._system.step(tinp, x[:,:,:-1].reshape(B,N*(T-1),n)).view(B,N,T-1,n)
        dfx = x[:,:,1:].unsqueeze(2) - fx.unsqueeze(1) # x_{t+1}^i - f(x_t^j), (B,N,N,T-1,n)
        flogprobs = self._Qdist.log_prob(dfx.view(B*N*N*(T-1),n)).view(B,N,N,T-1)
        I2 = (flogprobs * wij_tN).sum()
        # Compute I3
        tinp = torch.stack([torch.arange(T,dtype=torch.get_default_dtype())]*B, dim=0)
        tinp = tinp.unsqueeze(1).repeat((1,N,1)).view(B,T*N)
        gx = self._system.observe(tinp, x.reshape(B,N*T,n))
        dy = y.unsqueeze(1) - gx
        glogprobs = self._Rdist.log_prob(dy.view(B*N*T,m)).view(B,N,T)
        I3 = (glogprobs * w).sum()
        return I1 + I2 + I3 
    '''
    Misc methods
    '''
    def fa_loglikelihood(self, y, xfilt):
        '''
        Estimates the log-likelihood of the observations from faPF filtered particles
        Eqn 39 from https://arxiv.org/pdf/1703.02419.pdf
        Args:
            y (torch.tensor): (B,T,m) observations
            xfilt (torch.tensor): (B,N,T,n) filtered states
        Returns:
            log_z (torch.tensor): (1,) log p(y_1:T | x_filt, theta)
        '''
        B,N,T,n = xfilt.shape
        m = self._system._ydim
        xfilt = xfilt.detach()
        y = y.detach()
        # propagate
        tinp = torch.stack([torch.arange(T-1,dtype=torch.get_default_dtype())]*B, dim=0)
        fx = self._system.step(tinp, xfilt[:,:,:-1].reshape(B, N*(T-1),n))
        yfx = self._system.observe(tinp, fx).view(B,N,T-1,m)
        dy = y[:,1:].unsqueeze(1) - yfx
        v = self._Rdist.log_prob(dy.view(-1,m)).view(B,N,T-1).exp()
        vmean = v.mean(1)
        log_z = vmean.log().sum(-1).mean()
        return log_z 
        
    def compute_mean(self, x, w):
        '''
        Computes the mean trajectory from weighted samples
        Args:
            x (torch.tensor): (B,N,T,n) filtered states
            w (torch.tensor): (B,N,T) filtered weights
        Returns:
            xmean (torch.tensor): (B,T,n) mean states
        '''
        # ensure normalized
        w = w.clone() / w.sum(1).unsqueeze(1)
        w = w.unsqueeze(-1)
        return (x * w).sum(1)  
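

# Illustrative end-to-end sketch (not part of the original module): running the
# fully-adapted particle filter and FFBSi smoother on a toy linear-Gaussian system.
# `_ToyLinearSystem` is a hypothetical stand-in for a DiscreteDynamics model and only
# implements what faPF actually uses (step, observe, jac_obs_x, _xdim, _ydim, ydim);
# it also assumes a torch version that still provides torch.solve, as used in
# sample_initial above.
class _ToyLinearSystem:
    def __init__(self, A, C):
        self._A, self._C = A, C
        self._xdim, self._ydim = A.shape[0], C.shape[0]
        self.ydim = self._ydim
    def step(self, t, x):
        return x @ self._A.T                    # x_{t+1} = A x_t (time-invariant)
    def observe(self, t, x):
        return x @ self._C.T                    # y_t = C x_t
    def jac_obs_x(self, t, x):
        # constant observation Jacobian, expanded to (batch, K, ydim, xdim)
        return self._C.expand(x.shape[0], x.shape[-2], *self._C.shape)


def _fapf_example(T=20, N=50):
    n, m = 2, 1
    q_std, r_std = 0.2, 0.3
    system = _ToyLinearSystem(torch.tensor([[0.9, 0.1], [0.0, 0.8]]),
                              torch.tensor([[1.0, 0.0]]))
    fapf = faPF(N, system, Q=q_std**2 * torch.eye(n), R=r_std**2 * torch.eye(m),
                Px0=torch.eye(n))
    # simulate one observation sequence from the toy system
    x, ys = torch.randn(1, 1, n), []
    for t in range(T):
        ys.append(system.observe(t, x) + r_std * torch.randn(1, 1, m))
        x = system.step(t, x) + q_std * torch.randn(1, 1, n)
    y = torch.cat(ys, dim=1)                    # (1, T, m)
    # filter, smooth, and extract the weighted-mean state trajectory
    xfilt, xresamp, wfilt, mean_ll = fapf.filter(y)
    xsamps = fapf.FFBSi(xfilt, wfilt)           # (1, FFBSi_N, T, n)
    xmean = fapf.compute_mean(xfilt, wfilt)     # (1, T, n)
    return xmean, xsamps, mean_ll
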
def HarmonicDecayScheduler(k, a=50.0):
    a = float(a)
    return a/(k+a)  
class SAEMTrainer:
    def __init__(self, fapf, y, 
            gamma_sched=HarmonicDecayScheduler,
            max_k=100,
            xlen_cutoff=None):
        '''
        Stochastic-Approximation EM trainer
        Args:
            fapf (faPF): Fully-Adapted Particle Filter
            y (torch.tensor): (B,T,m) observation time-series
            gamma_sched (callable): step-size schedule, gamma_k = gamma_sched(k)
            max_k (int): number of SAEM iterations
            xlen_cutoff (int or None): maximum number of stored sampled-trajectory batches
        Notes: 
            https://projecteuclid.org/download/pdf_1/euclid.aos/1018031103 
        '''
        self._fapf = fapf
        self._y = y
        self._gamma_sched = gamma_sched
        self._max_k = max_k
        self._xlen_cutoff = xlen_cutoff
    def train(self, params, callbacks=[]):
        vparams0 = parameters_to_vector(params).clone()
        xsms = []
        t_start = timeit.default_timer()
        for k in range(self._max_k):
            with utils.Timer() as time:
                ## E-step
                xfilt, xfiltr, wfilt, meanll = self._fapf.filter(self._y)
                xsm = self._fapf.FFBSi(xfilt, wfilt)
                xsms.append(xsm)
                if self._xlen_cutoff:
                    if len(xsms) > self._xlen_cutoff:
                        xsms = xsms[-self._xlen_cutoff:]
            logger.logkv('train/Etime', time.dt)
            ## M-step
            with utils.Timer() as time:
                obj = lambda: -self.recursive_Q(xsms, self._y, 0, 0.)
                torchoptimizer(obj, params)
            logger.logkv('train/Mtime', time.dt)
            logger.logkv('train/elapsedtime', timeit.default_timer() - t_start)
            for callback in callbacks:
                callback(k)
            logger.dumpkvs()
        return params 
    def recursive_Q(self, xs, y, k_, Q):
        '''
        Recursive computation of Q
        Args:
            xs (list of torch.tensor): [(B,N,T,n)] iid smoothed trajs
            y (torch.tensor): (B,T,m) observations
            k_ (int): call level
            Q (torch.tensor): (1,) Q(theta) or 0.
        Returns:
            Q (torch.tensor): (1,) Q(theta)
        '''
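        # Unrolled, this is the stochastic-approximation average of the stored Q estimates;
        # e.g. for xs = [x1, x2, x3] and gam_k = gamma_sched(k):
        #   Q = (1-gam_2)(1-gam_1) gam_0 * Q_MCEM(y, x1)
        #     + (1-gam_2) gam_1          * Q_MCEM(y, x2)
        #     + gam_2                    * Q_MCEM(y, x3)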
        if len(xs) == 0:
            return Q
        else:
            gam = self._gamma_sched(k_)
            Q_ = self._fapf.Q_MCEM(y, xs[0])
            Qk = Q * (1-gam) + gam * Q_
            return self.recursive_Q(xs[1:], y, k_+1, Qk)   
def torchoptimizer(objfun, params, lr=1e-2, nepochs=10):
    opt = torch.optim.Adam(params, lr=lr)
    for epoch in range(nepochs):
        opt.zero_grad()
        loss = objfun()
        loss.backward()
        opt.step()
    return 
def scipyoptimizer(objfun, params, method='BFGS'):
    vparams0 = parameters_to_vector(params).clone()
    def eval_f(vparams):
        vparams = torch.tensor(vparams)
        vparams_ = parameters_to_vector(params)
        vector_to_parameters(vparams, params)
        with torch.no_grad():
            obj = objfun()
        vector_to_parameters(vparams_, params)
        return obj.detach().numpy()
    def eval_g(vparams):
        vparams = torch.tensor(vparams)
        vparams_ = parameters_to_vector(params)
        vector_to_parameters(vparams, params)
        # clear any stale gradients so backward() does not accumulate across calls
        for p in params:
            p.grad = None
        obj = objfun()
        obj.backward()
        grads = torch.cat([
            p.grad.view(-1) if p.grad is not None else torch.zeros_like(p).view(-1) for p in params
        ])
        grads = grads.detach().numpy()
        vector_to_parameters(vparams_, params)
        return grads
    with utils.Timer() as time:
        if method == 'BFGS':
            opt_result_ = minimize(eval_f, vparams0.detach().numpy(), 
                jac=eval_g, method='BFGS', options={'disp':True})
        elif method == 'Nelder-Mead':
            opt_result_ = minimize(eval_f, vparams0.detach().numpy(), 
                method='Nelder-Mead', options={'disp':True})
    vparams = opt_result_['x']
    return torch.tensor(vparams)
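

# Minimal usage sketch for the two optimizer helpers above (illustrative; the name
# `_optimizer_example` is not part of the original module): fitting a small parameter
# vector by minimizing a quadratic objective.
def _optimizer_example():
    target = torch.tensor([2.0, -3.0])
    param = torch.nn.Parameter(torch.zeros(2))
    objfun = lambda: ((param - target) ** 2).sum()
    torchoptimizer(objfun, [param], lr=1e-1, nepochs=200)      # Adam updates param in place
    vstar = scipyoptimizer(objfun, [param], method='BFGS')     # returns the optimized flat vector
    return param, vstar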