from collections import OrderedDict
import torch
from torch import nn
from ceem.dynamics import *
class LorenzAttractor(C2DSystem, nn.Module, AnalyticObsJacMixin, DynJacMixin):
    """Lorenz attractor with learnable (sigma, rho, beta) and a fixed linear observation.

    Continuous-time dynamics implemented in ``step_derivs``::

        xdot = sigma * (y - x)
        ydot = x * (rho - z) - y
        zdot = x * y - beta * z

    Observations are ``C @ state``.
    """

    def __init__(self, sigma, rho, beta, C, dt, method='midpoint'):
        """
        Args:
            sigma (torch.tensor): (1,) scalar
            rho (torch.tensor): (1,) scalar
            beta (torch.tensor): (1,) scalar
            C (torch.tensor): (ydim, n) observation matrix
            dt: forwarded unchanged to ``C2DSystem.__init__``
                (presumably the discretization time step -- confirm in C2DSystem)
            method (str): forwarded unchanged to ``C2DSystem.__init__``
                (presumably the integration scheme -- confirm in C2DSystem)
        """
        C2DSystem.__init__(self, dt=dt, method=method)
        nn.Module.__init__(self)

        # Learnable dynamics parameters.
        self._sigma = nn.Parameter(sigma)
        self._rho = nn.Parameter(rho)
        self._beta = nn.Parameter(beta)

        # Observation matrix is deliberately NOT a learned parameter. Stored as
        # (1, 1, ydim, n) so it broadcasts over the two leading dims of the state.
        # NOTE(review): as a plain tensor attribute it is not moved by
        # ``Module.to()`` / ``.cuda()`` and is not saved in ``state_dict``.
        self._C = C.unsqueeze(0).unsqueeze(0)

        self._xdim = 3
        self._ydim = C.shape[0]
        self._udim = None

    def step_derivs(self, t, x, u=None):
        """Evaluate the Lorenz vector field at ``x``.

        Args:
            t: unused
            x (torch.tensor): state whose last (third) axis holds (x, y, z);
                indexed as ``x[:, :, i]`` so it must have at least 3 dims
            u: unused

        Returns:
            torch.tensor: time derivative of the state, concatenated along dim 2
        """
        x = x.to(self._C.dtype)

        # Width-1 slices keep the trailing axis so the results can be concatenated.
        x_ = x[:, :, 0:1]
        y_ = x[:, :, 1:2]
        z_ = x[:, :, 2:3]

        xdot = self._sigma * (y_ - x_)
        ydot = x_ * (self._rho - z_) - y_
        zdot = x_ * y_ - self._beta * z_

        return torch.cat([xdot, ydot, zdot], dim=2)

    def observe(self, t, x, u=None):
        """Return the linear observation ``C @ x`` for every leading index.

        Args:
            t: unused
            x (torch.tensor): state; a trailing unit axis is added for the
                matrix product and removed afterwards
            u: unused

        Returns:
            torch.tensor: observations with last axis of size ``ydim``
        """
        x = x.to(self._C.dtype)
        return (self._C @ x.unsqueeze(3)).squeeze(3)

    def jac_obs_x(self, t, x, u=None):
        """Jacobian of the observation w.r.t. the state: ``C`` tiled over (B, T).

        Args:
            t: unused
            x (torch.tensor): (B, T, n) state; only its shape is read
            u: unused

        Returns:
            torch.tensor: (B, T, ydim, n)
        """
        # Three-way unpack also enforces that x is exactly 3-dimensional.
        B, T, _ = x.shape
        return self._C.repeat(B, T, 1, 1)

    def jac_obs_theta(self, t, x, u=None):
        """Jacobian of the observation w.r.t. the learned parameters.

        The observation does not depend on sigma, rho or beta, so every entry
        is ``None``.

        Args:
            t: unused
            x: unused
            u: unused

        Returns:
            OrderedDict: parameter name -> ``None``
        """
        return OrderedDict([
            ('_sigma', None),
            ('_rho', None),
            ('_beta', None),
        ])
def default_lorenz_attractor(seed=4, obsdif=1, dt=0.04):
    """Build a LorenzAttractor with the classic parameters and a seeded random C.

    Uses sigma=10, rho=28, beta=8/3 and the 'midpoint' method. The observation
    matrix is drawn with ``torch.randn`` under ``seed``; the caller's CPU RNG
    state is saved beforehand and restored afterwards, even if construction fails.

    Args:
        seed (int): seed used while drawing the observation matrix
        obsdif (int): observation dimension is ``3 - obsdif``
        dt: forwarded to ``LorenzAttractor``

    Returns:
        LorenzAttractor: the constructed system
    """
    n = 3
    sigma = torch.tensor([10.])
    rho = torch.tensor([28.])
    beta = torch.tensor([8. / 3])
    obsdim = n - obsdif

    saved_rng_state = torch.random.get_rng_state()
    try:
        torch.manual_seed(seed)
        C = torch.randn(obsdim, n)
        # Constructed inside the seeded window, as before, so any RNG use during
        # construction is also undone by the restore below.
        true_system = LorenzAttractor(sigma, rho, beta, C, dt, method='midpoint')
    finally:
        # Always restore, so an exception above cannot leave the caller's RNG reseeded.
        torch.random.set_rng_state(saved_rng_state)

    return true_system
def main():
    """Smoke-test a LorenzAttractor by printing its step output shape and Jacobians."""
    sigma = torch.tensor([10.])
    rho = torch.tensor([28.])
    beta = torch.tensor([8. / 3.])
    C = torch.randn(2, 3)

    # ``dt`` is a required constructor argument; omitting it raised TypeError.
    # 0.04 matches the default used by default_lorenz_attractor.
    csys = LorenzAttractor(sigma, rho, beta, C, dt=0.04)

    x = torch.randn(2, 2, 3)

    print(csys.step(None, x, None).shape)
    print(csys.jac_dyn_x(None, x, None))
    print(csys.jac_dyn_theta(None, x, None))
    print(csys.jac_obs_x(None, x, None))
    print(csys.jac_obs_theta(None, x, None))
# Run the smoke test only when executed as a script, not on import.
if __name__ == '__main__':
    main()