# Source code for ceem.systems.springmassdamper

from collections import OrderedDict

import torch
from torch import nn

from ceem.dynamics import (AnalyticObsJacMixin, C2DSystem, DynJacMixin, KinoDynamicalSystem)


class SpringMassDamper(KinoDynamicalSystem, C2DSystem, nn.Module, DynJacMixin, AnalyticObsJacMixin):
    """Generalized Spring Mass Damper System

    Continuous-time second-order dynamics

        qdot = v
        vdot = -M^{-1} (D v + K q)

    discretized by ``C2DSystem`` with time step ``dt`` using ``method``.
    The observation is the full state (``observe`` returns a copy of ``x``).

    M^{-1}, D and K are each stored as a learnable Cholesky-style factor
    ``L`` and reconstructed as ``L @ L^T``. The reconstructed matrices are
    therefore symmetric positive semi-definite by construction, whatever
    values the factors take during training.
    """

    def __init__(self, M, D, K, dt, method='midpoint'):
        """
        Args:
            M (torch.tensor): (qn,qn) positive-definite mass matrix
            D (torch.tensor): (qn,qn) positive-definite damping matrix
            K (torch.tensor): (qn,qn) positive-definite stiffness matrix
            dt (float): discretization time step, forwarded to C2DSystem
            method (str): continuous-to-discrete integration method,
                forwarded to C2DSystem (default 'midpoint')
        """
        C2DSystem.__init__(self, dt=dt, method=method)
        nn.Module.__init__(self)

        qn = M.shape[0]

        # Factor M^{-1}, D and K into lower-triangular Cholesky factors; the
        # factors (not the matrices) are the trainable parameters.
        # NOTE(review): Tensor.cholesky is deprecated in recent PyTorch in
        # favour of torch.linalg.cholesky; kept for compatibility with older
        # torch versions.
        Minv = M.inverse()
        Minv_chol = Minv.unsqueeze(0).cholesky()
        D_chol = D.unsqueeze(0).cholesky()
        K_chol = K.unsqueeze(0).cholesky()

        # Stored with shape (1, 1, qn, qn) so they broadcast against
        # (B, T, qn, 1) batched column vectors in dynamics().
        self._Minv_chol = nn.Parameter(Minv_chol.unsqueeze(0))
        self._D_chol = nn.Parameter(D_chol.unsqueeze(0))
        self._K_chol = nn.Parameter(K_chol.unsqueeze(0))

        self._qdim = qn
        self._vdim = qn

    def dynamics(self, t, q, v, u):
        """Return the acceleration vdot = -M^{-1} (D v + K q).

        Args:
            t: time (not used by this system)
            q (torch.tensor): positions, assumed (B,T,qn)
            v (torch.tensor): velocities, assumed (B,T,qn)
            u: input (not used by this system)

        Returns:
            vdot (torch.tensor): (B,T,qn) accelerations
        """
        Minv = self._Minv_chol @ self._Minv_chol.transpose(2, 3)
        D = self._D_chol @ self._D_chol.transpose(2, 3)
        K = self._K_chol @ self._K_chol.transpose(2, 3)
        # Promote q and v to column vectors (B,T,qn,1) for batched matmul,
        # then drop the trailing singleton dimension again.
        vdot = -(Minv @ (D @ v.unsqueeze(3) + K @ q.unsqueeze(3))).squeeze(3)
        return vdot

    def kinematics(self, t, q, v, u):
        """Return qdot, which for this system is simply the velocity v."""
        return v

    def observe(self, t, x, u=None):
        """Observe the full state.

        Returns a new tensor equal to ``x`` (the multiplication by 1.0 avoids
        handing back the caller's own tensor object).
        """
        return x * 1.0

    # NOTE(review): the dynamics Jacobian is not implemented analytically here;
    # it is presumably supplied by DynJacMixin. Confirm before relying on it.

    def jac_obs_x(self, t, x, u):
        """Jacobian of observe() with respect to x.

        Args:
            x (torch.tensor): (B,T,n) states

        Returns:
            torch.tensor: (B,T,n,n) batch of identity matrices with the same
            dtype and device as ``x``.
        """
        B, T, n = x.shape
        # Match x's dtype/device: a default torch.eye is CPU float32 and would
        # clash with double-precision or CUDA states downstream.
        eye = torch.eye(n, dtype=x.dtype, device=x.device)
        return eye.expand(B, T, n, n)

    def jac_obs_theta(self, t, x, u):
        """Jacobian of observe() with respect to each named parameter.

        observe() does not depend on any parameter, so every entry is None
        (presumably interpreted as "no dependence" by AnalyticObsJacMixin).
        """
        return OrderedDict([('_Minv_chol', None), ('_D_chol', None), ('_K_chol', None)])
def main():
    """Build a small 2-DoF example system and evaluate its dynamics once."""
    n = 4  # state dimension: qn = 2 positions + 2 velocities
    # Symmetric positive-definite (det = 1), so all Cholesky factorizations succeed.
    M = D = K = torch.tensor([[1., 2.], [2., 5.]])
    qn = M.shape[0]

    # dt is a required argument of SpringMassDamper; omitting it raised TypeError.
    csys = SpringMassDamper(M, D, K, dt=0.1)

    x = torch.randn(2, 2, n)
    q, v = x[..., :qn], x[..., qn:]
    print(csys.dynamics(None, q, v, None))


if __name__ == '__main__':
    main()