Source code for ceem.particleem

import timeit

import torch
from torch.distributions.categorical import Categorical
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.nn.utils import parameters_to_vector, vector_to_parameters

from scipy.optimize import minimize

from ceem import utils


class Resampler:

    def __init__(self, method):
        '''Class containing Multinomial, Systematic, and Stratified resamplers.

        Args:
            method (str): 'multinomial', 'stratified', or 'systematic'
        '''
        if method not in ['multinomial', 'stratified', 'systematic']:
            raise NotImplementedError
        else:
            self._method = method

    def __call__(self, weights):
        '''Resampling method.

        Args:
            weights (torch.tensor): (B, N) particle weights

        Returns:
            indices (torch.tensor): (B, N) particle indices
        '''
        return getattr(self, self._method)(weights)
    def dist(self, weights):
        return Categorical(weights)
    def icdf(self, weights, cds):
        '''Linear-time inverse CDF for a categorical distribution.

        Args:
            weights (torch.tensor): (B, N) particle weights
            cds (torch.tensor): (B, N) cumulative densities in ascending order

        Returns:
            indices (torch.tensor): (B, N) particle indices
        '''
        B, N = weights.shape
        indices = torch.zeros(B, N, dtype=torch.long)
        for b in range(B):
            i = 0
            cumden = weights[b, i]
            prev_cds = -1.0
            for n in range(N):
                assert prev_cds < cds[b, n], 'cds must be in ascending order.'
                prev_cds = cds[b, n]
                while cds[b, n] > cumden:
                    i += 1
                    cumden += weights[b, i]
                indices[b, n] = i
        return indices
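    # Worked example (added for illustration, not in the original source):
    # with weights = [[0.2, 0.5, 0.3]] and cds = [[0.1, 0.4, 0.9]], the running
    # cumulative densities are 0.2, 0.7, 1.0, so icdf returns indices [[0, 1, 2]]
    # (0.1 <= 0.2, then 0.2 < 0.4 <= 0.7, then 0.7 < 0.9 <= 1.0).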
    def multinomial(self, weights):
        N = weights.shape[1]
        dist = self.dist(weights)
        return dist.sample((N,)).T
    def stratified(self, weights):
        B, N = weights.shape
        cds = (torch.rand(B, N) +
               torch.arange(N, dtype=torch.get_default_dtype()).unsqueeze(0)) * (1. / N)
        return self.icdf(weights, cds)
    def systematic(self, weights):
        B, N = weights.shape
        cds = (torch.rand(B, 1) +
               torch.arange(N, dtype=torch.get_default_dtype()).unsqueeze(0)) * (1. / N)
        return self.icdf(weights, cds)
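
# Example (illustrative sketch, not part of the original module): resampling a
# batch of normalized weights. All three methods return a (B, N) tensor of
# indices into the particle dimension, drawn roughly in proportion to the weights.
#
#   resampler = Resampler('systematic')
#   weights = torch.tensor([[0.1, 0.2, 0.3, 0.4]])
#   inds = resampler(weights)   # shape (1, 4) long tensor of particle indices
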
class faPF:

    def __init__(self, N, system, Q, R, Px0, x0mean=None,
                 resampler=Resampler('systematic'), FFBSi_N=None, burnin_T=None,
                 ess_frac=0.5):
        """Fully-Adapted Particle Filter.

        Args:
            N (int): number of particles
            system (DiscreteDynamics): system model
            Q (torch.tensor): (xdim, xdim) process-noise covariance
            R (torch.tensor): (ydim, ydim) observation-noise covariance
            Px0 (torch.tensor): (xdim, xdim) initial-state covariance
            x0mean (torch.tensor): (1, xdim) initial-state mean
            resampler (Resampler): resampler
            FFBSi_N (int): number of backward trajectories to sample
            burnin_T (int or None): time steps skipped when computing the MCEM Q
            ess_frac (float): effective-sample-size fraction that triggers resampling
        """
        self._N = N
        self._system = system
        self._xdim = system._xdim
        self._ydim = system._ydim

        self._Q = Q.unsqueeze(0)
        self._Qinv = Q.inverse().unsqueeze(0)
        self._Qdist = MultivariateNormal(torch.zeros(self._xdim), Q)

        self._x0mean = x0mean if x0mean is not None else torch.zeros(1, self._xdim)
        self._Px0 = Px0.unsqueeze(0)
        self._Px0dist = MultivariateNormal(self._x0mean, Px0)

        self._R = R.unsqueeze(0)
        self._Rinv = R.inverse().unsqueeze(0)
        self._Rdist = MultivariateNormal(torch.zeros(self._ydim), R)

        self._resampler = resampler
        self._FFBSi_N = FFBSi_N if FFBSi_N else max(N // 10, 1)
        self._burnin_T = burnin_T
        self._ess_frac = ess_frac
    def filter(self, y):
        '''Run the filter.

        Args:
            y (torch.tensor): (B,T,m) observations

        Returns:
            x (torch.tensor): (B,N,T,n) filtered states
            xr (torch.tensor): (B,N,T,n) resampled filtered states
            w (torch.tensor): (B,N,T) filtered weights
            mean_ll (torch.tensor): (1,) mean per-step log-likelihood estimate
        '''
        N = self._N
        B, T, m = y.shape

        x = torch.zeros(B, N, T, self._xdim)
        xr = torch.zeros_like(x)
        w = torch.ones(B, N, T, dtype=torch.get_default_dtype()) / N  # (1/N for faPF)

        # sample initial distribution from p(x0 | y0)
        x[:, :, 0:1] = self.sample_initial(y[:, 0:1], N)

        v_unnorm = torch.zeros(B, N, T - 1)
        log_prev_v = torch.ones(B, N, dtype=torch.get_default_dtype())

        for t in range(1, T):
            # sample x_t from p(x_t | x_t-1, y_t)
            xtm1 = x[:, :, t - 1:t]
            yt = y[:, t:t + 1]

            # resample
            resampdist = self.get_resampling_distribution(t - 1, xtm1)
            log_vt = resampdist.log_prob(yt.repeat(1, N, 1).view(B * N, m)).view(B, N)
            v_unnorm[:, :, t - 1] = log_vt.exp()

            log_vt += log_prev_v
            log_vt -= log_vt.max(dim=1)[0].unsqueeze(1)
            vt = log_vt.exp()
            vt /= vt.sum(dim=1).unsqueeze(1)

            inds = self._resampler(vt.clone())

            xtm1r = xtm1.clone()
            for b in range(B):
                # adaptive resampling based on the effective sample size
                ESS = 1. / (vt[b] ** 2).sum()
                if ESS < self._ess_frac * N:
                    xtm1r[b] = xtm1[b, inds[b]]
                    log_prev_v[b] = torch.ones(N, dtype=torch.get_default_dtype())
                    # print('Resample at t=%d, ESS=%.3f' % (t, ESS))
                else:
                    log_prev_v[b] = log_vt[b]

            xr[:, :, t - 1:t] = xtm1r

            # propagate
            sampdist = self.get_sampling_distribution(t - 1, xtm1r, yt)
            xt = sampdist.sample().reshape(B, N, 1, self._xdim)

            x[:, :, t:t + 1] = xt

        mean_ll = v_unnorm.mean(1).log().mean()

        return x, xr, w, mean_ll
    def sample_initial(self, y0, N):
        '''Sample from p(x0 | y0).

        Args:
            y0 (torch.tensor): (B,1,m) initial observation

        Returns:
            x0 (torch.tensor): (B,N,1,n) samples of x0

        Notes:
            Eqns 352-354 of https://www.math.uwaterloo.ca/~hwolkowi/matrixcookbook.pdf
        '''
        B, _, m = y0.shape
        y0 = y0.squeeze(1)

        mu_x = self._x0mean.repeat(B, 1)

        tinp = torch.tensor([0] * B).unsqueeze(1).to(dtype=torch.get_default_dtype())

        C = self._system.jac_obs_x(tinp, mu_x.unsqueeze(1)).detach().squeeze(1)

        Sig_c = (C @ self._Px0).transpose(-2, -1)
        Sig_y = self._R + C @ self._Px0 @ C.transpose(-2, -1)

        muhat_x = mu_x + (Sig_c @ torch.solve(
            y0.unsqueeze(-1) - C @ mu_x.unsqueeze(-1), Sig_y)[0]).squeeze(-1)
        Sighat_x = self._Px0 - Sig_c @ torch.solve(Sig_c.transpose(-2, -1), Sig_y)[0]

        x0dist = MultivariateNormal(muhat_x, Sighat_x)

        # xtest = self._Px0dist.sample((100,))
        # ll_y_x = self._Rdist.log_prob((C @ xtest.transpose(-2,-1)).squeeze(-1) - y0).unsqueeze(-1)
        # ll_x = self._Px0dist.log_prob(xtest)
        # joint_ll = ll_y_x + ll_x
        # ll_x_y = x0dist.log_prob(xtest)
        # diff = ll_x_y - joint_ll
        # above diff is constant, implying this math is correct

        return x0dist.sample((N,)).transpose(0, 1).unsqueeze(2)
    def get_sampling_distribution(self, t, xtm1, yt):
        '''Sampling distribution p(xt | xt-1, yt).

        Args:
            t (int): time-index
            xtm1 (torch.tensor): (B,N,1,n) state-particles
            yt (torch.tensor): (B,1,m) observation

        Returns:
            dist (MultivariateNormal): p(xt | xt-1, yt)

        Notes:
            Implementation of
            https://link.springer.com/content/pdf/10.1023%2FA%3A1008935410038.pdf
            Equations 18-20
        '''
        B, N, _, xdim = xtm1.shape
        m = yt.shape[-1]

        tinp = torch.tensor([t] * B).unsqueeze(1).to(dtype=torch.get_default_dtype())

        fxtm1 = self._system.step(tinp - 1, xtm1.view(B * N, 1, xdim))

        Ht = self._system.jac_obs_x(t, fxtm1).detach().squeeze(1)
        HtT = Ht.transpose(-2, -1)

        yt_exp = self._system.observe(tinp, fxtm1)

        fxtm1 = fxtm1.squeeze(1)

        Sigtinv = self._Qinv + HtT @ self._Rinv @ Ht
        Sig = Sigtinv.inverse()

        yt_ = yt.repeat(1, N, 1).view(B * N, m, 1)
        yt_ += Ht @ fxtm1.unsqueeze(-1) - yt_exp.transpose(-2, -1)

        mt = Sig @ (self._Qinv @ fxtm1.unsqueeze(-1) + HtT @ self._Rinv @ yt_)
        mt = mt.squeeze(-1)

        return MultivariateNormal(mt, Sig)
    def get_resampling_distribution(self, t, xtm1):
        '''Resampling distribution p(yt | xt-1).

        Args:
            t (int): time-index
            xtm1 (torch.tensor): (B,N,1,n) state-particles

        Returns:
            dist (MultivariateNormal): p(yt | xt-1)

        Notes:
            Implementation of
            https://link.springer.com/content/pdf/10.1023%2FA%3A1008935410038.pdf
            Equation 14
        '''
        ydim = self._system.ydim
        B, N, _, xdim = xtm1.shape

        tinp = torch.tensor([t] * B).unsqueeze(1).to(dtype=torch.get_default_dtype())

        fxtm1 = self._system.step(tinp - 1, xtm1.view(B * N, 1, xdim))

        Ht = self._system.jac_obs_x(t, fxtm1).detach().squeeze(1)
        HtT = Ht.transpose(-2, -1)

        yt = self._system.observe(tinp, fxtm1)

        mt = yt.squeeze(1)
        Sig = self._R + Ht @ self._Q @ HtT

        return MultivariateNormal(mt, Sig)
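    # Summary (added for readability, not in the original source): the two
    # Gaussians above follow from linearizing the observation model about
    # f(x_{t-1}) with H_t = dg/dx:
    #
    #   p(y_t | x_{t-1})      ~ N( g(f(x_{t-1})),  R + H_t Q H_t^T )
    #   p(x_t | x_{t-1}, y_t) ~ N( m_t, Sigma_t ),  where
    #       Sigma_t^{-1} = Q^{-1} + H_t^T R^{-1} H_t
    #       m_t = Sigma_t ( Q^{-1} f(x_{t-1})
    #                       + H_t^T R^{-1} (y_t - g(f(x_{t-1})) + H_t f(x_{t-1})) )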
    ''' Smoothing Methods '''
    def FFBSm(self, y, x, w, return_wij_tN=False):
        '''Compute the smoothing marginal distributions.

        Args:
            y (torch.tensor): (B,T,m) observations
            x (torch.tensor): (B,N,T,n) filtered states
            w (torch.tensor): (B,N,T) filtered weights
            return_wij_tN (bool): return pair-wise smoothing weights

        Returns:
            xsm (torch.tensor): (B,N,T,n) smoothed states
            wsm (torch.tensor): (B,N,T) smoothed weights
            wij_tN (torch.tensor): (B,N,N,T-1) pairwise smoothing weights

        Notes:
            Implementation of Forward Filtering-Backward Smoothing,
            https://www.stats.ox.ac.uk/~doucet/doucet_johansen_tutorialPF2011.pdf
            Equation 49
        '''
        B, N, T, n = x.shape

        wsm = w.clone()
        wsm /= wsm.sum(1).unsqueeze(1)  # normalize

        wij_tN = torch.zeros(B, N, N, T - 1)

        for t in range(T - 1):
            tau = T - t - 2

            # compute p(X_tau+1 | X_tau)
            tinp = torch.tensor([t] * B).unsqueeze(1).to(dtype=torch.get_default_dtype())
            fxtau = self._system.step(tinp, x[:, :, tau])
            dxtaup1 = x[:, :, tau + 1].unsqueeze(2) - fxtau.unsqueeze(1)  # (B,N,N,n)
            fprobs = self._Qdist.log_prob(dxtaup1.view(B * N * N, n)).view(B, N, N).exp()

            den = (fprobs * w[:, :, tau:tau + 1]).sum(-2).unsqueeze(-2)

            wij_tN[:, :, :, tau] = wsm[:, :, tau].unsqueeze(-1) * (
                wsm[:, :, tau + 1:tau + 2].transpose(-2, -1) * fprobs / den)

            wsm[:, :, tau] = wij_tN[:, :, :, tau].sum(-1)
            wsm[:, :, tau] /= wsm[:, :, tau].sum(1).unsqueeze(1)  # normalize

            wij_tN[:, :, :, tau] /= wij_tN[:, :, :, tau].sum(-1).sum(-1).unsqueeze(-1).unsqueeze(-1)

        xsm = x.clone()

        if return_wij_tN:
            return xsm, wsm, wij_tN
        else:
            return xsm, wsm
    def FFBSi(self, x, w):
        '''Sample trajectories from the JSD using FFBSi.

        Args:
            x (torch.tensor): (B,N,T,n) filtered states
            w (torch.tensor): (B,N,T) filtered weights

        Returns:
            xsamps (torch.tensor): (B,Ns,T,n) sampled state trajectories

        Notes:
            See http://users.isy.liu.se/rt/schon/Publications/LindstenS2013.pdf
            Algorithm 4
        '''
        B, N, T, n = x.shape
        Ns = self._FFBSi_N

        xsamps = torch.zeros((B, Ns, T, n))

        # Sample at T
        bj = Categorical(w[:, :, -1]).sample((Ns,)).T
        for b in range(B):
            xsamps[b, :, -1] = x[b, bj[b], -1]

        for tau in range(T - 1):
            t = T - tau - 2

            tinp = torch.tensor([t] * B).unsqueeze(1).to(dtype=torch.get_default_dtype())
            fxt = self._system.step(tinp, x[:, :, t])
            dxtp1 = xsamps[:, :, t + 1].unsqueeze(2) - fxt.unsqueeze(1)  # (B,Ns,N,n)
            flgprobs = self._Qdist.log_prob(dxtp1.view(B * Ns * N, n)).view(B, Ns, N)

            # backward-sampling weights combine the transition log-probabilities
            # with the filtered weights
            wlgfprobs = flgprobs + w[:, :, t].unsqueeze(1).log()

            bj = Categorical(logits=wlgfprobs).sample()
            for b in range(B):
                xsamps[b, :, t] = x[b, bj[b], t]

        return xsamps
    ''' Methods for computing Q(theta | theta_k) '''
    def Q_MCEM(self, y, x):
        '''Compute a Monte-Carlo approximation of Q(theta, theta_k)
        using iid x_1:T ~ p(x_1:T | y_1:T).

        Args:
            y (torch.tensor): (B,T,m) observations
            x (torch.tensor): (B,N,T,n) sampled state trajectories

        Returns:
            Q (torch.tensor): (1,) Q(theta, theta_k)
        '''
        B, N, T, n = x.shape
        m = self._system.ydim
        burnin_T = self._burnin_T if self._burnin_T is not None else T // 10
        Ttil = T - burnin_T

        x = x.detach()
        y = y.detach()

        # Compute I1
        if burnin_T == 0:
            I1 = self._Px0dist.log_prob(x[:, :, 0].view(B * N, n)).view(B, N)
            I1 = I1.sum()
        else:
            I1 = 0.

        # Compute approx I2
        tinp = torch.stack([torch.arange(burnin_T, T - 1, dtype=torch.get_default_dtype())] * B, dim=0)
        fx = self._system.step(tinp, x[:, :, burnin_T:-1].reshape(B, N * (Ttil - 1), n)).view(
            B * N * (Ttil - 1), n)
        dx = x[:, :, burnin_T + 1:].reshape(B * N * (Ttil - 1), n) - fx
        I2 = self._Qdist.log_prob(dx).sum()

        # Compute I3
        tinp = torch.stack([torch.arange(burnin_T, T, dtype=torch.get_default_dtype())] * B, dim=0)
        tinp = tinp.unsqueeze(1).repeat((1, N, 1)).view(B, Ttil * N)
        gx = self._system.observe(tinp, x[:, :, burnin_T:].reshape(B, N * Ttil, n)).view(B, N, Ttil, m)
        dy = y[:, burnin_T:].unsqueeze(1) - gx
        glogprobs = self._Rdist.log_prob(dy.reshape(B * N * Ttil, m)).view(B, N, Ttil)
        I3 = glogprobs.sum()

        return (I1 + I2 + I3) / N
    def Q_JSD_marginals(self, y, x, w, wij_tN):
        '''Compute Q(theta, theta_k) using approximate JSD marginals.

        Args:
            y (torch.tensor): (B,T,m) observations
            x (torch.tensor): (B,N,T,n) filtered states
            w (torch.tensor): (B,N,T) filtered weights
            wij_tN (torch.tensor): (B,N,N,T-1) pair-wise smoothing weights

        Returns:
            Q (torch.tensor): (1,) Q(theta, theta_k)

        Notes:
            See http://user.it.uu.se/~thosc112/pubpdf/schonwn2011-2.pdf
            Equations 48-49
        '''
        if torch.any(torch.isnan(wij_tN)):
            import ipdb
            ipdb.set_trace()

        B, N, T, n = x.shape
        m = self._system.ydim

        x = x.detach()
        y = y.detach()
        w = w.detach()
        wij_tN = wij_tN.detach()

        # Compute I1
        I1 = self._Px0dist.log_prob(x[:, :, 0].view(B * N, n)).view(B, N)
        I1 = (I1 * w[:, :, 0]).sum()

        # Compute I2
        tinp = torch.stack([torch.arange(T - 1, dtype=torch.get_default_dtype())] * B, dim=0)
        tinp = tinp.unsqueeze(1).repeat((1, N, 1)).view(B, (T - 1) * N)
        fx = self._system.step(tinp, x[:, :, :-1].reshape(B, N * (T - 1), n)).view(B, N, T - 1, n)
        # transition residual pairs x_{t+1} with f(x_t)
        dfx = x[:, :, 1:].unsqueeze(2) - fx.unsqueeze(1)
        flogprobs = self._Qdist.log_prob(dfx.view(B * N * N * (T - 1), n)).view(B, N, N, T - 1)
        I2 = (flogprobs * wij_tN).sum()

        # Compute I3
        tinp = torch.stack([torch.arange(T, dtype=torch.get_default_dtype())] * B, dim=0)
        tinp = tinp.unsqueeze(1).repeat((1, N, 1)).view(B, T * N)
        gx = self._system.observe(tinp, x.reshape(B, N * T, n)).view(B, N, T, m)
        dy = y.unsqueeze(1) - gx
        glogprobs = self._Rdist.log_prob(dy.view(B * N * T, m)).view(B, N, T)
        I3 = (glogprobs * w).sum()

        return I1 + I2 + I3
    ''' Misc methods '''
    def fa_loglikelihood(self, y, xfilt):
        '''Estimate the log-likelihood of filtered particles from the faPF.

        Eqn 39 from https://arxiv.org/pdf/1703.02419.pdf

        Args:
            y (torch.tensor): (B,T,m) observations
            xfilt (torch.tensor): (B,N,T,n) filtered states

        Returns:
            z (torch.tensor): (1,) log p(y_1:T | x_filt, theta)
        '''
        B, N, T, n = xfilt.shape
        m = self._system._ydim

        xfilt = xfilt.detach()
        y = y.detach()

        # propagate
        tinp = torch.stack([torch.arange(T - 1, dtype=torch.get_default_dtype())] * B, dim=0)
        fx = self._system.step(tinp, xfilt[:, :, :-1].reshape(B, N * (T - 1), n))
        yfx = self._system.observe(tinp, fx).view(B, N, T - 1, m)
        dy = y[:, 1:].unsqueeze(1) - yfx
        v = self._Rdist.log_prob(dy.view(-1, m)).view(B, N, T - 1).exp()

        vmean = v.mean(1)

        log_z = vmean.log().sum(-1).mean()

        return log_z
    def compute_mean(self, x, w):
        '''Compute the mean trajectory from weighted samples.

        Args:
            x (torch.tensor): (B,N,T,n) filtered states
            w (torch.tensor): (B,N,T) filtered weights

        Returns:
            xmean (torch.tensor): (B,T,n) mean states
        '''
        # ensure normalized
        w = w.clone() / w.sum(1).unsqueeze(1)
        w = w.unsqueeze(-1)
        return (x * w).sum(1)
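
# Usage sketch (illustrative only, not part of the original module). Assuming a
# DiscreteDynamics instance `system` with state dimension n and observation
# dimension m, and observations `y` of shape (B, T, m), one particle-EM E-step
# could look like:
#
#   fapf = faPF(N=200, system=system, Q=torch.eye(n) * 1e-2,
#               R=torch.eye(m) * 1e-1, Px0=torch.eye(n))
#   xfilt, xfiltr, wfilt, mean_ll = fapf.filter(y)   # forward filtering pass
#   xsm = fapf.FFBSi(xfilt, wfilt)                   # backward-simulated trajectories
#   Q_val = fapf.Q_MCEM(y, xsm)                      # Monte-Carlo EM objective
#   xmean = fapf.compute_mean(xfilt, wfilt)          # (B, T, n) filtered mean
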
from ceem import logger
def HarmonicDecayScheduler(k, a=50.0):
    a = float(a)
    return a / (k + a)
class SAEMTrainer:

    def __init__(self, fapf, y, gamma_sched=HarmonicDecayScheduler, max_k=100,
                 xlen_cutoff=None):
        '''Stochastic-Approximation EM trainer.

        Args:
            fapf (faPF): Fully-Adapted Particle Filter
            y (torch.tensor): (B,T,m) observation time-series
            gamma_sched (callable): step-size schedule gamma(k)
            max_k (int): number of EM iterations
            xlen_cutoff (int or None): maximum number of stored smoothed-trajectory
                batches used in the stochastic approximation

        Notes:
            https://projecteuclid.org/download/pdf_1/euclid.aos/1018031103
        '''
        self._fapf = fapf
        self._y = y
        self._gamma_sched = gamma_sched
        self._max_k = max_k
        self._xlen_cutoff = xlen_cutoff
    def train(self, params, callbacks=[]):
        vparams0 = parameters_to_vector(params).clone()

        xsms = []

        t_start = timeit.default_timer()

        for k in range(self._max_k):

            ## E-step
            with utils.Timer() as time:
                xfilt, xfiltr, wfilt, meanll = self._fapf.filter(self._y)
                xsm = self._fapf.FFBSi(xfilt, wfilt)
                xsms.append(xsm)
                if self._xlen_cutoff:
                    if len(xsms) > self._xlen_cutoff:
                        xsms = xsms[-self._xlen_cutoff:]

            logger.logkv('train/Etime', time.dt)

            ## M-step
            with utils.Timer() as time:
                obj = lambda: -self.recursive_Q(xsms, self._y, 0, 0.)
                torchoptimizer(obj, params)

            logger.logkv('train/Mtime', time.dt)
            logger.logkv('train/elapsedtime', timeit.default_timer() - t_start)

            for callback in callbacks:
                callback(k)

            logger.dumpkvs()

        return params
    def recursive_Q(self, xs, y, k_, Q):
        '''Recursively compute Q.

        Args:
            xs (list of torch.tensor): [(B,N,T,n)] iid smoothed trajectories
            y (torch.tensor): (B,T,m) observations
            k_ (int): call level
            Q (torch.tensor): (1,) Q(theta) or 0.

        Returns:
            Q (torch.tensor): (1,) Q(theta)
        '''
        if len(xs) == 0:
            return Q
        else:
            gam = self._gamma_sched(k_)
            Q_ = self._fapf.Q_MCEM(y, xs[0])
            Qk = Q * (1 - gam) + gam * Q_
            return self.recursive_Q(xs[1:], y, k_ + 1, Qk)
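
# Usage sketch (illustrative only, not part of the original module). Assuming
# `fapf` is a configured faPF, `y` are (B, T, m) observations, and `params` is
# the list of torch parameters of the underlying system model:
#
#   trainer = SAEMTrainer(fapf, y, max_k=50)
#   trainer.train(params, callbacks=[lambda k: print('finished EM iteration', k)])
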
def torchoptimizer(objfun, params, lr=1e-2, nepochs=10):
    opt = torch.optim.Adam(params, lr=lr)
    for epoch in range(nepochs):
        opt.zero_grad()
        loss = objfun()
        loss.backward()
        opt.step()
    return
def scipyoptimizer(objfun, params, method='BFGS'):
    vparams0 = parameters_to_vector(params).clone()

    def eval_f(vparams):
        vparams = torch.tensor(vparams)
        vparams_ = parameters_to_vector(params)
        vector_to_parameters(vparams, params)
        with torch.no_grad():
            obj = objfun()
        vector_to_parameters(vparams_, params)
        return obj.detach().numpy()

    def eval_g(vparams):
        vparams = torch.tensor(vparams)
        vparams_ = parameters_to_vector(params)
        vector_to_parameters(vparams, params)
        # clear stale gradients so they do not accumulate across evaluations
        for p in params:
            if p.grad is not None:
                p.grad.detach_()
                p.grad.zero_()
        obj = objfun()
        obj.backward()
        grads = torch.cat([
            p.grad.view(-1) if p.grad is not None else torch.zeros_like(p).view(-1)
            for p in params
        ])
        grads = grads.detach().numpy()
        # restore the parameters in place before this evaluation, as in eval_f
        vector_to_parameters(vparams_, params)
        return grads

    with utils.Timer() as time:
        if method == 'BFGS':
            opt_result_ = minimize(eval_f, vparams0.detach().numpy(), jac=eval_g,
                                   method='BFGS', options={'disp': True})
        elif method == 'Nelder-Mead':
            opt_result_ = minimize(eval_f, vparams0.detach().numpy(),
                                   method='Nelder-Mead', options={'disp': True})

    vparams = opt_result_['x']

    return torch.tensor(vparams)
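
# Example (illustrative sketch, not part of the original module): fitting a
# single parameter vector with scipyoptimizer. `target` and `p` are hypothetical
# names introduced only for this example.
#
#   target = torch.tensor([1.0, 2.0])
#   p = torch.nn.Parameter(torch.zeros(2))
#   objfun = lambda: ((p - target) ** 2).sum()
#   vstar = scipyoptimizer(objfun, [p], method='BFGS')   # approx tensor([1., 2.])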