from collections import OrderedDict
import torch
from torch import nn
from ceem.dynamics import (AnalyticDynJacMixin, AnalyticObsJacMixin, DiscreteDynamicalSystem,
DynJacMixin, ObsJacMixin)
from ceem.nn import LNMLP
DEFAULT_NN = dict(hidden_sizes=[32] * 2, activation='tanh', gain=0.5, ln=False)
[docs]class DiscreteLinear(DiscreteDynamicalSystem, nn.Module, AnalyticDynJacMixin, ObsJacMixin):
"""Discrete dynamical system with linear dynamics and linear or non-linear observation model.
"""
def __init__(self, xdim, udim, ydim, A=None, B=None, C=None, D=None, obsModel=DEFAULT_NN):
"""
Args:
xdim (int): dimension of state vector x
udim (int): dimension of control vector u
ydim (int): dimension of observation vector y
A (torch.tensor): (xdim,xdim) model matrix
B (torch.tensor): (xdim,udim) model matrix
C (torch.tensor): (ydim,xdim) model matrix
D (torch.tensor): (ydim,udim) model matrix
obsModel (dict): dictionary of options for neural net. See nn.py/LNMLP.
If None, the observation model only uses C and D.
"""
super().__init__()
self._xdim = xdim
self._ydim = ydim
self._udim = udim
# Initialize model matrices
self._dyn = DynamicsModule(A, B, xdim, udim)
self._obs = ObservationModule(C, D, obsModel, xdim, udim, ydim)
[docs] def step(self, t, x, u):
return self._dyn(t, x, u)
[docs] def observe(self, t, x, u):
y = self._obs(t, x, u)
return y
[docs] def jac_step_x(self, t, x, u):
B, T, n = x.shape
return self._dyn._A.expand(B, T, n, n)
[docs] def jac_step_theta(self, t, x, u=None):
jacobians = OrderedDict([(name, None) for name, _ in self.named_parameters()])
B, T, n = x.shape
jacobians['._dyn._A'] = x.unsqueeze(-1).expand(B, T, self._xdim,
self._xdim).diag_embed(dim1=-2, dim2=-3)
jacobians['._dyn._B'] = u.unsqueeze(-1).expand(B, T, self._udim,
self._udim).diag_embed(dim1=-2, dim2=-3)
return jacobians
[docs]class DynamicsModule(nn.Module):
def __init__(self, A, B, xdim, udim):
super().__init__()
self._A = torch.nn.Parameter(-torch.eye(xdim, xdim) if A is None else A)
self._B = torch.nn.Parameter(torch.randn(xdim, udim) if B is None else B)
[docs] def forward(self, t, x, u):
return x @ self._A.t() + u @ self._B.t()
[docs]class ObservationModule(nn.Module):
def __init__(self, C, D, obsModel, xdim, udim, ydim):
super().__init__()
self._C = torch.nn.Parameter(torch.randn(ydim, xdim) if C is None else C)
self._D = torch.nn.Parameter(torch.randn(ydim, udim) if D is None else D)
# Create observation neural net
self._net = (lambda x: 0.) if obsModel is None else \
LNMLP(input_size=udim + xdim, output_size=ydim, **obsModel)
[docs] def forward(self, t, x, u):
return x @ self._C.t() + u @ self._D.t() + self._net(torch.cat([x, u], dim=-1))