from collections import OrderedDict
import torch
from torch import nn
from ceem.dynamics import (AnalyticObsJacMixin, C2DSystem, DynJacMixin, KinoDynamicalSystem)
class SpringMassDamper(KinoDynamicalSystem, C2DSystem, nn.Module, DynJacMixin, AnalyticObsJacMixin):
    """Generalized Spring Mass Damper System.

    Continuous-time second-order dynamics on the state ``x = [q, v]``::

        qdot = v
        vdot = -M^{-1} (D v + K q)

    ``M^{-1}``, ``D`` and ``K`` are stored as square factors ``L``
    (``_Minv_chol``, ``_D_chol``, ``_K_chol``, initialised to lower Cholesky
    factors). The matrices are rebuilt as ``L @ L^T``. This keeps them
    symmetric positive semi-definite whatever values the trainable factors take.

    The observation is the full state.
    """

    def __init__(self, M, D, K, dt, method='midpoint'):
        """
        Args:
            M (torch.tensor): (qn,qn) positive-definite mass matrix
            D (torch.tensor): (qn,qn) positive-definite damping matrix
            K (torch.tensor): (qn,qn) positive-definite stiffness matrix
            dt: time step, forwarded to ``C2DSystem.__init__``
            method (str): integration method name, forwarded to
                ``C2DSystem.__init__``

        Raises:
            RuntimeError: (from torch) if ``M`` is singular or if ``M^{-1}``,
                ``D`` or ``K`` is not positive-definite.
        """
        C2DSystem.__init__(self, dt=dt, method=method)
        nn.Module.__init__(self)

        qn = M.shape[0]

        # Lower Cholesky factors (A = L @ L^T). torch.linalg.cholesky replaces
        # the deprecated Tensor.cholesky and yields the same lower factor.
        Minv_chol = torch.linalg.cholesky(M.inverse())
        D_chol = torch.linalg.cholesky(D)
        K_chol = torch.linalg.cholesky(K)

        # Two leading singleton dims -> shape (1, 1, qn, qn), so the factors
        # broadcast against (B, T, qn, 1) operands in dynamics().
        self._Minv_chol = nn.Parameter(Minv_chol[None, None])
        self._D_chol = nn.Parameter(D_chol[None, None])
        self._K_chol = nn.Parameter(K_chol[None, None])

        self._qdim = qn
        self._vdim = qn

    @staticmethod
    def _gram(L):
        """Return ``L @ L^T`` over the last two dims (symmetric PSD by construction)."""
        return L @ L.transpose(-2, -1)

    def dynamics(self, t, q, v, u):
        """Return the acceleration ``vdot = -M^{-1} (D v + K q)``.

        Args:
            t: time (unused).
            q: positions, indexed as (B, T, qn).
            v: velocities, indexed as (B, T, qn).
            u: input (unused; the system is unforced).

        Returns:
            torch.Tensor: ``vdot`` with the same (B, T, qn) layout as ``v``.
        """
        Minv = self._gram(self._Minv_chol)
        D = self._gram(self._D_chol)
        K = self._gram(self._K_chol)
        # Column-vector form so the (1, 1, qn, qn) matrices broadcast over (B, T).
        force = D @ v.unsqueeze(-1) + K @ q.unsqueeze(-1)
        return -(Minv @ force).squeeze(-1)

    def kinematics(self, t, q, v, u):
        """Return ``qdot``, which equals the velocity ``v``."""
        return v

    def observe(self, t, x, u=None):
        """Observe the full state.

        Multiplying by 1.0 returns a new tensor rather than ``x`` itself.
        """
        return x * 1.0

    # NOTE(review): disabled closed-form dynamics Jacobian, kept for reference.
    # Presumably DynJacMixin supplies the Jacobian while this stays commented
    # out — confirm before deleting.
    # def kinematics_jacobian(self, t, q, v, u):
    #     B,T,qn = q.shape
    #     dqdotdq = torch.zeros(B,T,self._qdim,self._qdim)
    #     dqdotdv = torch.eye(self._qdim).expand(B,T,self._qdim,self._vdim)
    #     return torch.cat([dqdotdq,dqdotdv], dim=-1)
    # def dynamics_jacobian(self, t, q, v, u):
    #     B,T,qn = q.shape
    #     Minv = self._Minv_chol @ self._Minv_chol.transpose(2, 3)
    #     D = self._D_chol @ self._D_chol.transpose(2, 3)
    #     K = self._K_chol @ self._K_chol.transpose(2, 3)
    #     return -(Minv @ torch.cat([K, D], dim=-1)).expand(B,T,self._vdim,self._qdim + self._vdim)
    # def jac_dyn_x(self, t, x, u):
    #     """
    #     Return jac_x (xdot)
    #     """
    #     B,T,n = x.shape
    #     q = x[:,:,:self._qdim]
    #     v = x[:,:,self._qdim:]
    #     dqdotdx = self.kinematics_jacobian(t,q,v,u)
    #     dvdotdx = self.dynamics_jacobian(t,q,v,u)
    #     return torch.cat([dqdotdx,dvdotdx], dim=-2)

    def jac_obs_x(self, t, x, u):
        """Return d(observe)/dx: the identity, expanded to shape (B, T, n, n).

        The identity is built with ``x``'s dtype and device. A default
        float32 CPU tensor would mismatch float64 or GPU states. The result
        is an expanded (stride-0) view and must not be written to in place.
        """
        B, T, n = x.shape
        eye = torch.eye(n, dtype=x.dtype, device=x.device)
        return eye.expand(B, T, n, n)

    def jac_obs_theta(self, t, x, u):
        """Return observation Jacobians w.r.t. each parameter.

        Every entry is None because observe() does not use any parameter.
        """
        return OrderedDict([('_Minv_chol', None), ('_D_chol', None), ('_K_chol', None)])
def main():
    """Smoke-test: build a 2-DOF system and evaluate its dynamics once."""
    M = D = K = torch.tensor([[1., 2.], [2., 5.]])
    # SpringMassDamper requires a time step; the previous call omitted it and
    # raised TypeError. 0.1 is an arbitrary example value.
    dt = 0.1
    csys = SpringMassDamper(M, D, K, dt)

    qn = M.shape[0]
    # State x = [q, v] with layout (B, T, n), n = qdim + vdim = 2 * qn.
    x = torch.randn(2, 2, 2 * qn)
    q, v = x[..., :qn], x[..., qn:]
    print(csys.dynamics(None, q, v, None))
    # print(csys.jac_dyn_x(None, x, None))


if __name__ == '__main__':
    main()