#
# File: opt_criteria.py
#
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn.functional as F
from torch.autograd import grad
from torch.nn.utils import parameters_to_vector
from ceem import logger
from ceem.dynamics import (AnalyticDynJacMixin, AnalyticObsJacMixin, DynJacMixin, ObsJacMixin)


class Criterion:

    def __call__(self, model, x, **kwargs):
        return self.forward(model, x, **kwargs)

    def batched_forward(self, model, x, **kwargs):
        """
        Forward method for computing the criterion, not summed over the batch
        Args:
            model (DiscreteDynamicalSystem)
            x (torch.tensor): (B,T,n) system states
        Returns:
            y (torch.tensor): (B,) criterion
        """
        B, T, n = x.shape
        y = torch.stack([self(model, x[i:i + 1], **kwargs) for i in range(B)])
        return y

    def batched_sample_forward(self, model, x, **kwargs):
        """
        Forward method for computing the criterion, not summed over the batch or samples
        Args:
            model (DiscreteDynamicalSystem)
            x (torch.tensor): (B,N,T,n) sampled system states
        Returns:
            y (torch.tensor): (B,N) criterion
        """
        B, N, T, n = x.shape
        ys = [self.batched_forward(model, x[:, i], **kwargs) for i in range(N)]
        return torch.stack(ys, dim=1)

    def forward(self, model, x, **kwargs):
        """
        Forward method for computing criterion
        Args:
            model (DiscreteDynamicalSystem)
            x (torch.tensor): (B,T,n) system states
        Returns:
            criterion (torch.tensor): scalar criterion
        """
        raise NotImplementedError

    def jac_x(self, model, x, **kwargs):
        """
        Method for computing jacobian of criterion wrt x
        Args:
            model (DiscreteDynamicalSystem)
            x (torch.tensor): (B,T,n) system states
        Returns:
            jac_x (torch.tensor): (B,T,n) criterion jacobian
        """
        x.requires_grad = True
        criterion = self.forward(model, x, **kwargs)
        jac_x = grad(criterion, x)[0]
        return jac_x

    def jac_theta(self, model, x, **kwargs):
        """
        Method for computing the criterion jacobian wrt model parameters
        Args:
            model (DiscreteDynamicalSystem)
            x (torch.tensor): (B,T,n) system states
        Returns:
            jac_theta (tuple of torch.tensor): gradient wrt each model parameter
        """
        criterion = self.forward(model, x, **kwargs)
        return grad(criterion, list(model.parameters()), allow_unused=True)
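

# Illustrative sketch (not part of the library API): a minimal Criterion subclass
# showing the interface that `forward` must satisfy -- it returns a scalar
# torch.tensor so that `jac_x` and `jac_theta` can differentiate it. The class
# name below is hypothetical.
class _ExampleStateNormCriterion(Criterion):

    def forward(self, model, x, **kwargs):
        # Penalize the squared norm of the state trajectory; `model` is unused here.
        return 0.5 * (x**2).sum()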


class GroupCriterion(Criterion):
    """
    A group of Criterion instances whose values are summed
    """

    def __init__(self, criteria):
        self._criteria = criteria

    def forward(self, model, x, **kwargs):
        return sum(c(model, x, **kwargs) for c in self._criteria)


class STRParamCriterion(Criterion):
    """
    Soft trust-region criterion on the model parameters
    """

    def __init__(self, rho, params):
        self._rho = rho
        self._params = params
        self._vparams0 = parameters_to_vector(params).clone().detach()

    def forward(self, model, x, **kwargs):
        vparams = parameters_to_vector(self._params)
        return self._rho * torch.sum((vparams - self._vparams0)**2)
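

# Illustrative usage sketch (hypothetical `model` and `x`): anchor a soft trust
# region around the current parameter vector, e.g. between outer optimization steps.
#
#   str_crit = STRParamCriterion(rho=1e-2, params=list(model.parameters()))
#   penalty = str_crit(model, x)   # rho * ||theta - theta0||^2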


class SOSCriterion(Criterion):
    """
    Sum-of-squares criterion
    """

    def forward(self, model, x, **kwargs):
        return 0.5 * (self.residuals(model, x, **kwargs)**2).sum()

    def residuals(self, model, x, **kwargs):
        """
        Forward method for computing SOS residuals
        Args:
            model (DiscreteDynamicalSystem)
            x (torch.tensor): (B,T,n) system states
        Returns:
            residuals (torch.tensor): (nresid,) residuals
        """
        raise NotImplementedError

    def scaled_jac_x_diag(self, model, x, sparse=False):
        """
        Returns the diagonal blocks of the residual Jacobian wrt x
        Args:
            model (DiscreteDynamicalSystem)
            x (torch.tensor): (B,T,n) system states
        Returns:
            jac (torch.tensor): (B,T,m,n) diagonal Jacobian blocks; dynamics criteria
                additionally return the scaled negative-identity off-diagonal block
        """
        raise NotImplementedError

    def jac_resid_x(self, model, x, sparse=False, sparse_format=sp.csr_matrix, **kwargs):
        """
        Method for computing the residual jacobian wrt x
        Args:
            model (DiscreteDynamicalSystem)
            x (torch.tensor): (B,T,n) system states
            sparse (bool): if True, returns a (nresid, B*T*n) matrix in sparse_format
            sparse_format (scipy.sparse): sparse matrix format
        Returns:
            jac_resid_x (torch.tensor): (nresid,B,T,n) residual jacobian
        """
        logger.warn('Using a very inefficient jac_resid_x')
        x = x.detach()
        x.requires_grad_(True)
        resid = self.residuals(model, x, **kwargs)
        # Differentiate each residual separately: O(nresid) backward passes.
        jac_resid_x = []
        for r in resid:
            jac_resid_x.append(grad(r, x, retain_graph=True)[0])
        jac_resid_x = torch.stack(jac_resid_x, dim=0)
        nresid, B, T, n = jac_resid_x.shape
        if sparse:
            return sparse_format(
                jac_resid_x.view(nresid, B * T * n).detach().numpy(), dtype=np.float64)
        else:
            return jac_resid_x
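

# Illustrative note (a sketch, not library code): since `forward` evaluates
# 0.5 * ||r(x)||^2, a dense Gauss-Newton step on the states could be written as
# follows (hypothetical `crit`, `model`, `x`; assumes the dense Jacobian fits in memory):
#
#   r = crit.residuals(model, x).detach()
#   J = crit.jac_resid_x(model, x).view(r.numel(), -1)
#   dx = torch.linalg.lstsq(J, -r.unsqueeze(-1)).solution.view_as(x)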


class GaussianObservationCriterion(SOSCriterion):
    """
    Gaussian observation sum-of-squares criterion
    """

    def __init__(self, Sig_y_inv, t, y, u=None):
        self._t = t
        self._y = y
        self._u = u
        self._size = Sig_y_inv.shape[0]
        if Sig_y_inv.ndim == 1:
            # assuming diagonal inverse covariance
            self._Sig_y_inv_chol = Sig_y_inv.sqrt().unsqueeze(0).unsqueeze(0)
            self._diagcov = True
        elif Sig_y_inv.ndim == 2:
            # assuming full inverse covariance
            self._Sig_y_inv_chol = Sig_y_inv.cholesky().unsqueeze(0).unsqueeze(0)
            self._diagcov = False
        elif Sig_y_inv.ndim == 3:
            # assuming full, time-varying inverse covariance
            self._Sig_y_inv_chol = Sig_y_inv.cholesky().unsqueeze(0)
            self._diagcov = False
        else:
            raise ValueError('Sig_y_inv must have 1, 2, or 3 dimensions')

    def apply_inds(self, x, inds):
        if inds is not None:
            t = self._t[inds]
            x = x[inds]
            y = self._y[inds]
            u = self._u[inds] if self._u is not None else None
        else:
            t = self._t
            y = self._y
            u = self._u
        return t, x, y, u

    def residuals(self, model, x, inds=None, flatten=True):
        t, x, y, u = self.apply_inds(x, inds)
        ypred = model.observe(t, x, u)
        err = ypred - y
        if self._diagcov:
            resid = self._Sig_y_inv_chol * err
        else:
            resid = (self._Sig_y_inv_chol @ err.unsqueeze(-1)).squeeze(-1)
        return resid.view(-1) if flatten else resid

    def scaled_jac_x_diag(self, model, x, inds=None):
        t, x, y, u = self.apply_inds(x, inds)
        jac_obs_x = model.jac_obs_x(t, x, u)
        if self._diagcov:
            jac_resid_x_ = self._Sig_y_inv_chol.unsqueeze(-1) * jac_obs_x
        else:
            jac_resid_x_ = self._Sig_y_inv_chol @ jac_obs_x
        J = jac_resid_x_.detach()
        return J

    def jac_resid_x(self, model, x, sparse=False, sparse_format=sp.csr_matrix, inds=None):
        if isinstance(model, ObsJacMixin):
            # scaled_jac_x_diag applies inds itself, so pass the unsliced x through
            jac_resid_x_ = self.scaled_jac_x_diag(model, x, inds=inds)
            B, T, m, n = jac_resid_x_.shape
            # currently (B,T,m,n)
            if sparse:
                idx = torch.arange(B * T)
                idptr = torch.arange(B * T + 1)
                jac_resid_x = sp.bsr_matrix((jac_resid_x_.view(B * T, m, n), idx, idptr),
                                            shape=(T * m * B, T * n * B), dtype=np.float64)
                return sparse_format(jac_resid_x, dtype=np.float64)
            else:
                jac_resid_x = torch.zeros(B * T * m, B, T, n)
                for b in range(B):
                    for t in range(T):
                        jac_resid_x[(b * T + t) * m:(b * T + t + 1) *
                                    m, b, t, :] = jac_resid_x_[b, t]
                return jac_resid_x
        else:
            raise NotImplementedError('model must implement ObsJacMixin')


class STRStateCriterion(SOSCriterion):
    """
    Soft trust-region criterion on the states
    """

    def __init__(self, rho, x0):
        self._size = x0.shape[-1]
        self._rho = rho
        self._x0 = x0

    def residuals(self, model, x, inds=None, flatten=True):
        res = self._rho * (x - self._x0)
        return res.view(-1) if flatten else res

    def scaled_jac_x_diag(self, model, x, inds=None):
        B, T, n = x.shape
        jac_resid = (self._rho * torch.eye(n).unsqueeze(0).unsqueeze(0)).repeat(B, T, 1, 1)
        return jac_resid.detach()

    def jac_resid_x(self, model, x, sparse=False, sparse_format=sp.csr_matrix):
        B, T, n = x.shape
        if sparse:
            return self._rho * sp.eye(B * T * n, format=sparse_format.format, dtype=np.float64)
        else:
            return self._rho * torch.eye(B * T * n).to(dtype=x.dtype).view(B * T * n, B, T, n)
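

# Illustrative note: with residuals rho * (x - x0), the SOSCriterion value is
# 0.5 * rho**2 * ||x - x0||^2, i.e. a quadratic penalty keeping the smoothed states
# near the previous iterate x0 (hypothetical usage):
#
#   str_state_crit = STRStateCriterion(rho=1.0, x0=x.detach().clone())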


class GaussianX0Criterion(SOSCriterion):
    """
    Gaussian prior criterion on the initial state x0
    """

    def __init__(self, x0, Sig_x0_inv):
        self._size = x0.shape[-1]
        self._x0 = x0
        if Sig_x0_inv.ndim == 1:
            # assuming diagonal inverse covariance
            self._Sig_x0_inv_chol = Sig_x0_inv.sqrt().unsqueeze(0).unsqueeze(0)
            self._diagcov = True
        else:
            # assuming full inverse covariance
            self._Sig_x0_inv_chol = Sig_x0_inv.cholesky().unsqueeze(0).unsqueeze(0)
            self._diagcov = False

    def residuals(self, model, x, inds=None, flatten=True):
        err = (x[:, 0] - self._x0)
        if self._diagcov:
            resid = self._Sig_x0_inv_chol * err
        else:
            resid = (self._Sig_x0_inv_chol @ err.unsqueeze(-1)).squeeze(-1)
        return resid.view(-1) if flatten else resid

    def jac_resid_x(self, model, x, sparse=False, sparse_format=sp.csr_matrix):
        B, T, n = x.shape
        if self._diagcov:
            cov = self._Sig_x0_inv_chol.squeeze().diag()
        else:
            cov = self._Sig_x0_inv_chol.squeeze(0).squeeze(0)
        retval = torch.zeros(B * n, B, T, n, dtype=x.dtype)
        for b in range(B):
            retval[b * n:(b + 1) * n, b, 0, :] = cov
        if sparse:
            return sparse_format(retval.view(-1, B * T * n).detach().numpy(), dtype=np.float64)
        else:
            return retval


class GaussianDynamicsCriterion(SOSCriterion):
    """
    Gaussian dynamics (process-noise) sum-of-squares criterion
    """

    def __init__(self, Sig_w_inv, t, u=None):
        self._t = t
        self._u = u
        self._size = Sig_w_inv.shape[0]
        if Sig_w_inv.ndim == 1:
            # assuming diagonal inverse covariance
            self._Sig_w_inv_chol = Sig_w_inv.sqrt().unsqueeze(0).unsqueeze(0)
            self._diagcov = True
        else:
            # assuming full inverse covariance
            self._Sig_w_inv_chol = Sig_w_inv.cholesky().unsqueeze(0).unsqueeze(0)
            self._diagcov = False

    def apply_inds(self, x, inds):
        if inds is not None:
            t = self._t[inds]
            x = x[inds]
            u = self._u[inds] if self._u is not None else None
        else:
            t = self._t
            u = self._u
        return t, x, u

    def residuals(self, model, x, inds=None, flatten=True):
        t, x, u = self.apply_inds(x, inds)
        u = u[:, :-1] if u is not None else None
        xpred = model.step(t[:, :-1], x[:, :-1], u)
        err = xpred - x[:, 1:]
        if self._diagcov:
            resid = self._Sig_w_inv_chol * err
        else:
            resid = (self._Sig_w_inv_chol @ err.unsqueeze(-1)).squeeze(-1)
        return resid.view(-1) if flatten else resid

    def scaled_jac_x_diag(self, model, x, inds=None):
        t, x, u = self.apply_inds(x, inds)
        u = u[:, :-1] if u is not None else None
        jac_dyn_x_ = model.jac_step_x(t[:, :-1], x[:, :-1], u)
        if self._diagcov:
            jac_resid_x_ = self._Sig_w_inv_chol.unsqueeze(-1) * jac_dyn_x_
            neyew = -torch.diag_embed(self._Sig_w_inv_chol.view(-1))
        else:
            jac_resid_x_ = self._Sig_w_inv_chol @ jac_dyn_x_
            neyew = -self._Sig_w_inv_chol.squeeze(0).squeeze(0)
        J = jac_resid_x_.detach()
        return J, neyew

    def jac_resid_x(self, model, x, sparse=False, sparse_format=sp.csr_matrix, inds=None):
        if isinstance(model, DynJacMixin):
            # scaled_jac_x_diag applies inds itself, so pass the unsliced x through
            jac_resid_x_, neyew = self.scaled_jac_x_diag(model, x, inds=inds)
            # currently (B,T-1,n,n)
            B, Tm1, n, _ = jac_resid_x_.shape
            T = Tm1 + 1
            if sparse:
                # Create the right indices given the block-diagonal structure
                # There is a slight subtlety at the boundaries between batches
                idx_ = torch.arange((T - 1))
                idx = torch.cat([idx_ + b * T for b in range(B)])
                idptr = torch.arange(B * (T - 1) + 1)
                # Create the block-diagonal matrix with the dynamics jacobians
                J1 = sp.bsr_matrix((jac_resid_x_.view(B * (T - 1), n, n), idx, idptr),
                                   shape=((T - 1) * n * B, T * n * B), dtype=np.float64)
                # Create the first upper block diagonal with the scaled negative identity
                E = (neyew.unsqueeze(0).repeat(B * (T - 1), 1, 1))
                J2 = sp.bsr_matrix((E, idx + 1, idptr), shape=((T - 1) * n * B, T * n * B),
                                   dtype=np.float64)
                return sparse_format(J1 + J2, dtype=np.float64)
            else:
                jac_resid_x = torch.zeros(B * (T - 1) * n, B, T, n)
                for b in range(B):
                    for t in range(T - 1):
                        jac_resid_x[(b * (T - 1) + t) * n:(b * (T - 1) + t + 1) *
                                    n, b, t, :] = jac_resid_x_[b, t]
                        jac_resid_x[(b * (T - 1) + t) * n:(b * (T - 1) + t + 1) * n, b, t +
                                    1, :] = neyew
                return jac_resid_x
        else:
            return super().jac_resid_x(model, x, sparse=sparse, sparse_format=sparse_format,
                                       inds=inds)
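

# Illustrative usage sketch (hypothetical shapes and variables; `model` must
# implement `step(t, x, u)` and, for analytic Jacobians, DynJacMixin):
#
#   Sig_w_inv = torch.ones(n)    # diagonal inverse process-noise covariance
#   dyn_crit = GaussianDynamicsCriterion(Sig_w_inv, t)
#   # dyn_crit.residuals(model, x) whitens model.step(t[:, :-1], x[:, :-1], u) - x[:, 1:].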


class BasicGroupSOSCriterion(SOSCriterion):
    """
    Base class for groups of SOSCriterion instances
    """


class GroupSOSCriterion(BasicGroupSOSCriterion):
    """
    Group of SOSCriterion instances with stacked residuals and Jacobians
    """

    def __init__(self, criteria):
        self._criteria = criteria

    def residuals(self, model, x, **kwargs):
        residuals_list = [c.residuals(model, x, **kwargs) for c in self._criteria]
        return torch.cat(residuals_list, dim=0)

    def jac_resid_x(self, model, x, sparse=False, sparse_format=sp.csr_matrix, **kwargs):
        jacs = [
            c.jac_resid_x(model, x, sparse=sparse, sparse_format=sparse_format, **kwargs)
            for c in self._criteria
        ]
        if sparse:
            if sparse_format is not sp.coo_matrix:
                return sparse_format(sp.vstack(jacs), dtype=np.float64)
            else:
                return sp.vstack(jacs)
        else:
            return torch.cat(jacs, dim=0)
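

# Illustrative usage sketch (hypothetical `obs_crit`, `dyn_crit`, `model`, and `x`):
# GroupSOSCriterion simply stacks residuals and Jacobians, so the criteria may appear
# in any order and only need to implement `jac_resid_x`, not `scaled_jac_x_diag`.
#
#   group = GroupSOSCriterion([obs_crit, dyn_crit])
#   Jsp = group.jac_resid_x(model, x, sparse=True, sparse_format=sp.csr_matrix)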


class BlockSparseGroupSOSCriterion(BasicGroupSOSCriterion):
    """
    Group of SOSCriterion instances with careful design of the sparsity pattern in the
    Jacobian matrix.
    All SOSCriterion instances but the last must have block-diagonal Jacobians and a
    `.scaled_jac_x_diag` method.
    The last SOSCriterion, typically the dynamics, contributes blocks on the main and
    first upper block diagonals.
    """

    def __init__(self, criteria):
        self._criteria = criteria
        self._size = sum([c._size for c in criteria])
        self._mem = {}

    def residuals(self, model, x, **kwargs):
        B, T, n = x.shape
        residuals = torch.zeros(B, T, self._size)
        # Assign all residuals of length T
        s = 0
        for c in self._criteria[:-1]:
            residuals[:, :, s:s + c._size] = c.residuals(model, x, **kwargs, flatten=False)
            s += c._size
        # Assign the dynamics residuals (length T-1)
        residuals[:, :-1, -n:] = self._criteria[-1].residuals(model, x, **kwargs, flatten=False)
        return residuals.view(-1)

    def jac_resid_x(self, model, x, sparse_format=sp.bsr_matrix, **kwargs):
        B, T, n = x.shape
        # Check for memoized sparsity pattern
        key = f'{B},{T},{n}'
        if key in self._mem:
            pattern = self._mem[key]
        else:
            pattern = {}
            # For J1
            idx_ = torch.arange(T)
            pattern['idx1'] = torch.cat([idx_ + b * T for b in range(B)])
            pattern['idptr1'] = torch.arange(B * T + 1)
            # For J2
            pattern['idx2'] = pattern['idx1'].view(B, T)[:, 1:].reshape(-1)
            idptr_ = torch.arange(T)
            pattern['idptr2'] = torch.cat([idptr_ + b * (T - 1) for b in range(B)] +
                                          [torch.tensor([B * (T - 1)])])
            self._mem[key] = pattern
        # Obtain building blocks from subcriteria
        J = []
        for i in range(len(self._criteria) - 1):
            J.append(self._criteria[i].scaled_jac_x_diag(model, x))
        m = self._size
        Jd, E = self._criteria[-1].scaled_jac_x_diag(model, x)
        Jd = F.pad(Jd, pad=(0, 0, 0, 0, 0, 1), mode='constant', value=0)
        # Build main diagonal
        D = torch.cat(J + [Jd], 2)
        J1 = sp.bsr_matrix((D.view(-1, m, n), pattern['idx1'], pattern['idptr1']),
                           shape=(B * T * m, B * T * n), dtype=np.float64)
        # Add off-diag term
        E = (E.unsqueeze(0).unsqueeze(0).repeat(B, (T - 1), 1, 1))
        E = F.pad(E, (0, 0, m - n, 0), mode='constant', value=0).view(B * (T - 1), m, n)
        J2 = sp.bsr_matrix((E, pattern['idx2'], pattern['idptr2']), shape=(B * T * m, B * T * n),
                           dtype=np.float64)
        return sparse_format(J1 + J2)
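

# Illustrative usage sketch (hypothetical `obs_crit`, `dyn_crit`, `model`, and `x`):
# the dynamics criterion must be placed last so its off-diagonal blocks line up with
# the sparsity pattern constructed above.
#
#   group = BlockSparseGroupSOSCriterion([obs_crit, dyn_crit])
#   r = group.residuals(model, x)      # (B*T*(m+n),) flattened residuals
#   J = group.jac_resid_x(model, x)    # scipy BSR matrix of shape (B*T*(m+n), B*T*n)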